{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
module Language.Halide.Func
(
Func (..)
, FuncTy (..)
, Stage (..)
, buffer
, scalar
, define
, (!)
, realize
, Schedulable (..)
, TailStrategy (..)
, computeRoot
, getStage
, getLoopLevel
, getLoopLevelAtStage
, asUsed
, asUsedBy
, copyToDevice
, copyToHost
, storeAt
, computeAt
, dim
, estimate
, bound
, getArgs
, update
, hasUpdateDefinitions
, getUpdateStage
, prettyLoopNest
, IndexTuple
, asBufferParam
, withFunc
, withBufferParam
, wrapCxxFunc
, CxxStage
, wrapCxxStage
, withCxxStage
)
where
import Control.Exception (bracket)
import Control.Monad (forM)
import Data.IORef
import Data.Kind (Type)
import Data.Proxy
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Foreign.ForeignPtr
import Foreign.Marshal (toBool, with)
import Foreign.Ptr (Ptr, castPtr)
import GHC.Stack (HasCallStack)
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Dimension
import Language.Halide.Expr
import Language.Halide.LoopLevel
import Language.Halide.Target
import Language.Halide.Type
import Language.Halide.Utils
import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding (min, tail)
-- | Opaque tag for a C++ @Halide::Stage@. It has no Haskell values and is only
-- used behind 'Ptr' / 'ForeignPtr' (e.g. a @Ptr CxxStage@ is passed as
-- @Halide::Stage*@ to the inline-C++ blocks in this module).
data CxxStage

-- Template Haskell splice defined in another module (presumably
-- "Language.Halide.Context"). NOTE(review): assumed to set up the inline-C++
-- context (type mappings such as @Halide::Stage@ <-> 'CxxStage', includes)
-- used by the quasi-quotes in this module; confirm.
importHalide
-- | A Halide function handle, indexed by a 'FuncTy' tag @t@, a type-level
-- natural @n@ (presumably the number of dimensions — confirm) and the element
-- type @a@. The tag records which constructor is in use.
data Func (t :: FuncTy) (n :: Nat) (a :: Type) where
  -- | Wraps a C++ @Halide::Func@; tagged 'FuncTy'.
  Func :: {-# UNPACK #-} !(ForeignPtr CxxFunc) -> Func 'FuncTy n a
  -- | A buffer parameter, tagged 'ParamTy'. The 'IORef' may hold 'Nothing';
  -- NOTE(review): presumably until the underlying @Halide::ImageParam@ has
  -- been created — confirm.
  Param :: {-# UNPACK #-} !(IORef (Maybe (ForeignPtr CxxImageParam))) -> Func 'ParamTy n a
-- | Type-level tag selecting a 'Func' constructor: 'FuncTy' marks functions
-- built with the 'Func' constructor and 'ParamTy' marks 'Param' buffers.
data FuncTy
  = FuncTy
  | ParamTy
  deriving stock (Show, Eq, Ord)
-- | Handle to a C++ @Halide::Stage@ (see 'CxxStage'). The @n@ and @a@
-- parameters are phantom (they do not appear on the right-hand side) and
-- mirror those of the 'Func' the stage was obtained from via 'getStage'.
newtype Stage (n :: Nat) (a :: Type) = Stage (ForeignPtr CxxStage)
-- | Tail strategy accepted by 'split'. Each constructor corresponds to the
-- @Halide::TailStrategy@ enumerator of the same name without the @Tail@
-- prefix; the numeric conversion lives in the 'Enum' instance.
data TailStrategy
  = -- | @Halide::TailStrategy::RoundUp@
    TailRoundUp
  | -- | @Halide::TailStrategy::GuardWithIf@
    TailGuardWithIf
  | -- | @Halide::TailStrategy::Predicate@
    TailPredicate
  | -- | @Halide::TailStrategy::PredicateLoads@
    TailPredicateLoads
  | -- | @Halide::TailStrategy::PredicateStores@
    TailPredicateStores
  | -- | @Halide::TailStrategy::ShiftInwards@
    TailShiftInwards
  | -- | @Halide::TailStrategy::Auto@
    TailAuto
  deriving stock (Eq, Ord, Show)
-- | Constraint synonym: @i@ is a tuple-like value convertible (with
-- 'fromTuple', via 'IsTuple') to @'Arguments' ts@, and every element type in
-- @ts@ is @'Expr' 'Int32'@. Used, for example, by 'gpuBlocks' and
-- 'gpuThreads', which additionally bound @'Length' ts@ to 1..3.
type IndexTuple i ts = (IsTuple (Arguments ts) i, All ((~) (Expr Int32)) ts)
-- | Scheduling directives shared by 'Func' and 'Stage'.
--
-- In this module the 'Stage' instance calls the corresponding
-- @Halide::Stage@ member function, and the 'Func' instance fetches the
-- function's stage with 'getStage' and delegates to it. In both instances,
-- methods returning @f n a@ return their input argument (the schedule is
-- modified on the C++ side).
class (KnownNat n, IsHalideType a) => Schedulable f n a where
  -- | Vectorize the given dimension (@Halide::Stage::vectorize@).
  vectorize :: VarOrRVar -> f n a -> IO (f n a)
  -- | Unroll the given dimension (@Halide::Stage::unroll@).
  unroll :: VarOrRVar -> f n a -> IO (f n a)
  -- | Reorder the given dimensions; the list is passed to
  -- @Halide::Stage::reorder@ as a @std::vector@ in the same order.
  reorder :: [VarOrRVar] -> f n a -> IO (f n a)
  -- | @split tail old (outer, inner) factor@ calls
  -- @Halide::Stage::split(old, outer, inner, factor, tail)@.
  split :: TailStrategy -> VarOrRVar -> (VarOrRVar, VarOrRVar) -> Expr Int32 -> f n a -> IO (f n a)
  -- | @fuse (v1, v2) fused@ calls @Halide::Stage::fuse(v1, v2, fused)@.
  -- NOTE(review): Halide's C++ declaration names the first two parameters
  -- @inner@ and @outer@, in that order; confirm the order callers expect.
  fuse :: (VarOrRVar, VarOrRVar) -> VarOrRVar -> f n a -> IO (f n a)
  -- | Mark the given dimension serial (@Halide::Stage::serial@).
  serial :: VarOrRVar -> f n a -> IO (f n a)
  -- | Mark the given dimension parallel (@Halide::Stage::parallel@).
  parallel :: VarOrRVar -> f n a -> IO (f n a)
  -- | Specialize on a boolean condition and return the resulting new 'Stage'
  -- (@Halide::Stage::specialize@).
  specialize :: Expr Bool -> f n a -> IO (Stage n a)
  -- | @Halide::Stage::specialize_fail@ with the given message, passed to C++
  -- as UTF-8.
  specializeFail :: Text -> f n a -> IO ()
  -- | Map 1 to 3 variables, given as a tuple, to GPU blocks for the given
  -- device API (@Halide::Stage::gpu_blocks@).
  gpuBlocks :: (IndexTuple i ts, 1 <= Length ts, Length ts <= 3) => DeviceAPI -> i -> f n a -> IO (f n a)
  -- | Map 1 to 3 variables, given as a tuple, to GPU threads for the given
  -- device API (@Halide::Stage::gpu_threads@).
  gpuThreads :: (IndexTuple i ts, 1 <= Length ts, Length ts <= 3) => DeviceAPI -> i -> f n a -> IO (f n a)
  -- | Map the given dimension to GPU lanes (@Halide::Stage::gpu_lanes@).
  gpuLanes :: DeviceAPI -> VarOrRVar -> f n a -> IO (f n a)
  -- | @Halide::Stage::compute_with@ at the given loop level, using the given
  -- alignment strategy.
  computeWith :: LoopAlignStrategy -> f n a -> LoopLevel t -> IO ()
-- | Scheduling applied directly to a 'Stage'.
--
-- Every directive runs the matching @Halide::Stage@ member function inside
-- @handle_halide_exceptions@ in a 'C.throwBlock', so C++ exceptions are
-- rethrown in Haskell rather than escaping through the FFI. Directives that
-- return a 'Stage' hand back their input (the C++ object is modified in
-- place), except 'specialize', which returns a newly allocated stage.
instance (KnownNat n, IsHalideType a) => Schedulable Stage n a where
  vectorize :: VarOrRVar -> Stage n a -> IO (Stage n a)
  vectorize v st = st <$ withCxxStage st go
    where
      go stagePtr =
        asVarOrRVar v $ \varPtr ->
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Stage* stagePtr)->vectorize(*$(const Halide::VarOrRVar* varPtr));
            });
          } |]

  unroll :: VarOrRVar -> Stage n a -> IO (Stage n a)
  unroll v st = st <$ withCxxStage st go
    where
      go stagePtr =
        asVarOrRVar v $ \varPtr ->
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Stage* stagePtr)->unroll(*$(const Halide::VarOrRVar* varPtr));
            });
          } |]

  -- The whole list is marshalled into one std::vector first; the stage
  -- pointer is acquired inside that scope.
  reorder :: [VarOrRVar] -> Stage n a -> IO (Stage n a)
  reorder vs st = st <$ withMany asVarOrRVar vs go
    where
      go varsPtr =
        withCxxStage st $ \stagePtr ->
          [C.throwBlock| void {
            handle_halide_exceptions([=]() {
              $(Halide::Stage* stagePtr)->reorder(
                *$(const std::vector<Halide::VarOrRVar>* varsPtr));
            });
          } |]

  split
    :: TailStrategy
    -> VarOrRVar
    -> (VarOrRVar, VarOrRVar)
    -> Expr Int32
    -> Stage n a
    -> IO (Stage n a)
  split strategy old (outer, inner) factor st = st <$ withCxxStage st go
    where
      -- Numeric value of the Halide enumerator (see the Enum instance).
      tailCode :: CInt
      tailCode = fromIntegral (fromEnum strategy)
      go stagePtr =
        asVarOrRVar old $ \oldPtr ->
          asVarOrRVar outer $ \outerPtr ->
            asVarOrRVar inner $ \innerPtr ->
              asExpr factor $ \factorPtr ->
                [C.throwBlock| void {
                  handle_halide_exceptions([=](){
                    $(Halide::Stage* stagePtr)->split(
                      *$(const Halide::VarOrRVar* oldPtr),
                      *$(const Halide::VarOrRVar* outerPtr),
                      *$(const Halide::VarOrRVar* innerPtr),
                      *$(const Halide::Expr* factorPtr),
                      static_cast<Halide::TailStrategy>($(int tailCode)));
                  });
                } |]

  -- NOTE(review): the pair is forwarded, in order, as the first two
  -- arguments of Halide::Stage::fuse, which Halide's C++ declaration names
  -- (inner, outer); confirm the order callers expect.
  fuse :: (VarOrRVar, VarOrRVar) -> VarOrRVar -> Stage n a -> IO (Stage n a)
  fuse (v1, v2) fused st = st <$ withCxxStage st go
    where
      go stagePtr =
        asVarOrRVar v1 $ \v1Ptr ->
          asVarOrRVar v2 $ \v2Ptr ->
            asVarOrRVar fused $ \fusedPtr ->
              [C.throwBlock| void {
                handle_halide_exceptions([=](){
                  $(Halide::Stage* stagePtr)->fuse(
                    *$(const Halide::VarOrRVar* v1Ptr),
                    *$(const Halide::VarOrRVar* v2Ptr),
                    *$(const Halide::VarOrRVar* fusedPtr));
                });
              } |]

  serial :: VarOrRVar -> Stage n a -> IO (Stage n a)
  serial v st = st <$ withCxxStage st go
    where
      go stagePtr =
        asVarOrRVar v $ \varPtr ->
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Stage* stagePtr)->serial(*$(const Halide::VarOrRVar* varPtr));
            });
          } |]

  parallel :: VarOrRVar -> Stage n a -> IO (Stage n a)
  parallel v st = st <$ withCxxStage st go
    where
      go stagePtr =
        asVarOrRVar v $ \varPtr ->
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Stage* stagePtr)->parallel(*$(const Halide::VarOrRVar* varPtr));
            });
          } |]

  -- Unlike the other directives, this returns a new heap-allocated stage,
  -- which is wrapped with wrapCxxStage.
  specialize :: Expr Bool -> Stage n a -> IO (Stage n a)
  specialize cond st =
    withCxxStage st $ \stagePtr ->
      asExpr cond $ \condPtr -> do
        specialized <-
          [C.throwBlock| Halide::Stage* {
            return handle_halide_exceptions([=](){
              return new Halide::Stage{$(Halide::Stage* stagePtr)->specialize(
                *$(const Halide::Expr* condPtr))};
            });
          } |]
        wrapCxxStage specialized

  specializeFail :: Text -> Stage n a -> IO ()
  specializeFail msg st =
    withCxxStage st $ \stagePtr ->
      [C.throwBlock| void {
        return handle_halide_exceptions([=](){
          $(Halide::Stage* stagePtr)->specialize_fail(
            std::string{$bs-ptr:msgBytes, static_cast<size_t>($bs-len:msgBytes)});
        });
      } |]
    where
      -- The message crosses into C++ as UTF-8 bytes.
      msgBytes = T.encodeUtf8 msg

  gpuBlocks
    :: forall i (ts :: [Type])
     . (IndexTuple i ts, 1 <= Length ts, Length ts <= 3)
    => DeviceAPI
    -> i
    -> Stage n a
    -> IO (Stage n a)
  gpuBlocks device vars st = do
    withCxxStage st $ \stagePtr ->
      asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple vars) $ \varsPtr ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            auto const& vars = *$(const std::vector<Halide::VarOrRVar>* varsPtr);
            auto& stage = *$(Halide::Stage* stagePtr);
            auto const device = static_cast<Halide::DeviceAPI>($(int api));
            switch (vars.size()) {
              case 1: stage.gpu_blocks(vars.at(0), device);
                      break;
              case 2: stage.gpu_blocks(vars.at(0), vars.at(1), device);
                      break;
              case 3: stage.gpu_blocks(vars.at(0), vars.at(1), vars.at(2), device);
                      break;
              default: throw std::runtime_error{"unexpected number of arguments in gpuBlocks"};
            }
          });
        } |]
    pure st
    where
      api :: CInt
      api = fromIntegral (fromEnum device)

  gpuThreads
    :: forall i (ts :: [Type])
     . (IndexTuple i ts, 1 <= Length ts, Length ts <= 3)
    => DeviceAPI
    -> i
    -> Stage n a
    -> IO (Stage n a)
  gpuThreads device vars st = do
    withCxxStage st $ \stagePtr ->
      asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple vars) $ \varsPtr ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            auto const& vars = *$(const std::vector<Halide::VarOrRVar>* varsPtr);
            auto& stage = *$(Halide::Stage* stagePtr);
            auto const device = static_cast<Halide::DeviceAPI>($(int api));
            switch (vars.size()) {
              case 1: stage.gpu_threads(vars.at(0), device);
                      break;
              case 2: stage.gpu_threads(vars.at(0), vars.at(1), device);
                      break;
              case 3: stage.gpu_threads(vars.at(0), vars.at(1), vars.at(2), device);
                      break;
              default: throw std::runtime_error{"unexpected number of arguments in gpuThreads"};
            }
          });
        } |]
    pure st
    where
      api :: CInt
      api = fromIntegral (fromEnum device)

  gpuLanes :: DeviceAPI -> VarOrRVar -> Stage n a -> IO (Stage n a)
  gpuLanes device v st = st <$ withCxxStage st go
    where
      api :: CInt
      api = fromIntegral (fromEnum device)
      go stagePtr =
        asVarOrRVar v $ \varPtr ->
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Stage* stagePtr)->gpu_lanes(
                *$(const Halide::VarOrRVar* varPtr),
                static_cast<Halide::DeviceAPI>($(int api)));
            });
          } |]

  computeWith :: LoopAlignStrategy -> Stage n a -> LoopLevel t -> IO ()
  computeWith strategy st level =
    withCxxStage st $ \stagePtr ->
      withCxxLoopLevel level $ \levelPtr ->
        [C.throwBlock| void {
          handle_halide_exceptions([=]() {
            $(Halide::Stage* stagePtr)->compute_with(
              *$(const Halide::LoopLevel* levelPtr),
              static_cast<Halide::LoopAlignStrategy>($(int alignCode)));
          });
        } |]
    where
      alignCode :: CInt
      alignCode = fromIntegral (fromEnum strategy)
-- | Lift a one-argument 'Stage' directive to a 'Func'. The directive runs on
-- the function's stage (from 'getStage'), its result is discarded, and the
-- 'Func' itself is returned.
viaStage1
  :: (KnownNat n, IsHalideType b)
  => (a -> Stage n b -> IO (Stage n b))
  -> a
  -> Func t n b
  -> IO (Func t n b)
viaStage1 action x func = func <$ (action x =<< getStage func)
-- | Two-argument version of 'viaStage1': run the directive on the function's
-- stage, discard its result, and return the 'Func'.
viaStage2
  :: (KnownNat n, IsHalideType b)
  => (a1 -> a2 -> Stage n b -> IO (Stage n b))
  -> a1
  -> a2
  -> Func t n b
  -> IO (Func t n b)
viaStage2 action x y func = func <$ (action x y =<< getStage func)
-- | Four-argument version of 'viaStage1' (used for 'split'): run the
-- directive on the function's stage, discard its result, and return the
-- 'Func'.
viaStage4
  :: (KnownNat n, IsHalideType b)
  => (a1 -> a2 -> a3 -> a4 -> Stage n b -> IO (Stage n b))
  -> a1
  -> a2
  -> a3
  -> a4
  -> Func t n b
  -> IO (Func t n b)
viaStage4 action x1 x2 x3 x4 func =
  func <$ (action x1 x2 x3 x4 =<< getStage func)
-- | Scheduling on a 'Func': every directive is forwarded to the function's
-- 'Stage' (obtained with 'getStage'). Directives returning a 'Func' return
-- the argument itself; 'specialize' returns the new 'Stage'.
instance (KnownNat n, IsHalideType a) => Schedulable (Func t) n a where
  vectorize :: VarOrRVar -> Func t n a -> IO (Func t n a)
  vectorize = viaStage1 vectorize

  unroll :: VarOrRVar -> Func t n a -> IO (Func t n a)
  unroll = viaStage1 unroll

  reorder :: [VarOrRVar] -> Func t n a -> IO (Func t n a)
  reorder = viaStage1 reorder

  split
    :: TailStrategy
    -> VarOrRVar
    -> (VarOrRVar, VarOrRVar)
    -> Expr Int32
    -> Func t n a
    -> IO (Func t n a)
  split = viaStage4 split

  fuse :: (VarOrRVar, VarOrRVar) -> VarOrRVar -> Func t n a -> IO (Func t n a)
  fuse = viaStage2 fuse

  serial :: VarOrRVar -> Func t n a -> IO (Func t n a)
  serial = viaStage1 serial

  parallel :: VarOrRVar -> Func t n a -> IO (Func t n a)
  parallel = viaStage1 parallel

  specialize :: Expr Bool -> Func t n a -> IO (Stage n a)
  specialize cond func = do
    st <- getStage func
    specialize cond st

  specializeFail :: Text -> Func t n a -> IO ()
  specializeFail msg func = do
    st <- getStage func
    specializeFail msg st

  gpuBlocks
    :: forall i (ts :: [Type])
     . (IndexTuple i ts, 1 <= Length ts, Length ts <= 3)
    => DeviceAPI
    -> i
    -> Func t n a
    -> IO (Func t n a)
  gpuBlocks = viaStage2 gpuBlocks

  gpuThreads
    :: forall i (ts :: [Type])
     . (IndexTuple i ts, 1 <= Length ts, Length ts <= 3)
    => DeviceAPI
    -> i
    -> Func t n a
    -> IO (Func t n a)
  gpuThreads = viaStage2 gpuThreads

  gpuLanes :: DeviceAPI -> VarOrRVar -> Func t n a -> IO (Func t n a)
  gpuLanes = viaStage2 gpuLanes

  -- The loop level gets its own type variable @l@: @t@ here is the Func's
  -- FuncTy tag, which has a different kind from the loop-level index.
  computeWith :: LoopAlignStrategy -> Func t n a -> LoopLevel l -> IO ()
  computeWith align func level = do
    st <- getStage func
    computeWith align st level
-- | Conversion between 'TailStrategy' and the integer values of the C++ enum
-- @Halide::TailStrategy@. The numbers come from the C++ enumerators
-- themselves (via 'CU.pure'), so nothing is hard-coded on the Haskell side.
-- 'toEnum' calls 'error' for a value that matches no constructor.
instance Enum TailStrategy where
  fromEnum :: TailStrategy -> Int
  fromEnum strategy = fromIntegral halideCode
    where
      halideCode :: CInt
      halideCode = case strategy of
        TailRoundUp -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::RoundUp) } |]
        TailGuardWithIf -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::GuardWithIf) } |]
        TailPredicate -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::Predicate) } |]
        TailPredicateLoads -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateLoads) } |]
        TailPredicateStores -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateStores) } |]
        TailShiftInwards -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::ShiftInwards) } |]
        TailAuto -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::Auto) } |]

  toEnum :: Int -> TailStrategy
  toEnum k =
    case [s | s <- candidates, fromIntegral k == (fromIntegral (fromEnum s) :: CInt)] of
      (s : _) -> s
      [] -> error $ "invalid TailStrategy: " <> show k
    where
      -- Tried in declaration order; the first constructor whose Halide code
      -- equals k (compared as a C int) wins.
      candidates =
        [ TailRoundUp
        , TailGuardWithIf
        , TailPredicate
        , TailPredicateLoads
        , TailPredicateStores
        , TailShiftInwards
        , TailAuto
        ]
-- | Set an estimate for the range of one of a 'Func''s arguments by calling
-- @Halide::Func::set_estimate(var, min, extent)@.
--
-- Any C++ exception Halide raises (it validates these arguments) is rethrown
-- in Haskell via 'C.throwBlock' and @handle_halide_exceptions@, the same way
-- the scheduling directives in this module handle errors.
estimate
  :: (KnownNat n, IsHalideType a)
  => Expr Int32
  -- ^ the argument variable; converted with 'asVar' (which carries a
  -- 'HasCallStack' constraint and presumably rejects non-@Var@ expressions)
  -> Expr Int32
  -- ^ estimated minimum
  -> Expr Int32
  -- ^ estimated extent
  -> Func t n a
  -- ^ the function whose argument is being estimated
  -> IO ()
estimate var min extent func =
  withFunc func $ \f ->
    asVar var $ \i ->
      asExpr min $ \minExpr ->
        asExpr extent $ \extentExpr ->
          -- throwBlock rather than an unsafe CU.exp: set_estimate may throw,
          -- and a C++ exception must not unwind through the FFI call.
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Func* f)->set_estimate(
                *$(Halide::Var* i), *$(Halide::Expr* minExpr), *$(Halide::Expr* extentExpr));
            });
          } |]
-- | Constrain the range of one of a 'Func''s arguments by calling
-- @Halide::Func::bound(var, min, extent)@.
--
-- Any C++ exception Halide raises (it validates these arguments) is rethrown
-- in Haskell via 'C.throwBlock' and @handle_halide_exceptions@, the same way
-- the scheduling directives in this module handle errors.
bound
  :: (KnownNat n, IsHalideType a)
  => Expr Int32
  -- ^ the argument variable; converted with 'asVar' (which carries a
  -- 'HasCallStack' constraint and presumably rejects non-@Var@ expressions)
  -> Expr Int32
  -- ^ minimum
  -> Expr Int32
  -- ^ extent
  -> Func t n a
  -- ^ the function whose argument is being bounded
  -> IO ()
bound var min extent func =
  withFunc func $ \f ->
    asVar var $ \i ->
      asExpr min $ \minExpr ->
        asExpr extent $ \extentExpr ->
          -- throwBlock rather than an unsafe CU.exp: bound may throw, and a
          -- C++ exception must not unwind through the FFI call.
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Func* f)->bound(
                *$(Halide::Var* i), *$(Halide::Expr* minExpr), *$(Halide::Expr* extentExpr));
            });
          } |]
-- | Return the argument variables of a 'Func', as reported by
-- @Halide::Func::args()@, in the same order.
--
-- The C++ vector is copied into a temporary that 'bracket' frees even if
-- copying an element fails; each element is copy-constructed into a new
-- 'Var'.
getArgs :: (KnownNat n, IsHalideType a) => Func t n a -> IO [Var]
getArgs func =
  withFunc func $ \func' -> do
    let allocate :: IO (Ptr (CxxVector CxxVar))
        allocate =
          [CU.exp| std::vector<Halide::Var>* {
            new std::vector<Halide::Var>{$(const Halide::Func* func')->args()} } |]
        destroy :: Ptr (CxxVector CxxVar) -> IO ()
        destroy v = [CU.exp| void { delete $(std::vector<Halide::Var>* v) } |]
    bracket allocate destroy $ \v -> do
      n <- [CU.exp| size_t { $(const std::vector<Halide::Var>* v)->size() } |]
      -- n is an unsigned CSize: for a function with no arguments, n - 1 would
      -- wrap around to maxBound and [0 .. n - 1] would index an empty vector.
      let indices = if n == 0 then [] else [0 .. n - 1]
      forM indices $ \i ->
        fmap Var . cxxConstruct $ \ptr ->
          [CU.exp| void {
            new ($(Halide::Var* ptr)) Halide::Var{$(const std::vector<Halide::Var>* v)->at($(size_t i))} } |]
-- | Call @Halide::Func::compute_root()@ on the function and return the same
-- 'Func'. Halide errors are rethrown in Haskell via 'C.throwBlock'.
computeRoot :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Func t n a)
computeRoot func = func <$ withFunc func rootIt
  where
    rootIt f =
      [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Func* f)->compute_root(); }); } |]
-- | Wrap Halide's @g.in(f)@.
--
-- Returns a fresh 'Func' constructed from the result of calling
-- @Func::in(const Func &)@ on @g@ with @f@ as the argument.
asUsedBy
  :: (KnownNat n, KnownNat m, IsHalideType a, IsHalideType b)
  => Func t1 n a
  -> Func 'FuncTy m b
  -> IO (Func 'FuncTy n a)
asUsedBy g f =
  withFunc g $ \gPtr ->
    withFunc f $ \fPtr ->
      wrapCxxFunc
        =<< [CU.exp| Halide::Func* {
              new Halide::Func{$(Halide::Func* gPtr)->in(*$(Halide::Func* fPtr))} } |]
-- | Wrap Halide's argument-less @f.in()@.
--
-- Returns a fresh 'Func' constructed from the result of @Func::in()@.
asUsed :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Func 'FuncTy n a)
asUsed f =
  withFunc f $ \fPtr ->
    wrapCxxFunc
      =<< [CU.exp| Halide::Func* { new Halide::Func{$(Halide::Func* fPtr)->in()} } |]
-- | Call Halide's @Func::copy_to_device@ with the given device API.
--
-- Halide errors are rethrown on the Haskell side. The same 'Func' is returned
-- for chaining.
copyToDevice :: (KnownNat n, IsHalideType a) => DeviceAPI -> Func t n a -> IO (Func t n a)
copyToDevice deviceApi func = do
  withFunc func $ \f ->
    [C.throwBlock| void {
      handle_halide_exceptions([=](){
        $(Halide::Func* f)->copy_to_device(static_cast<Halide::DeviceAPI>($(int api)));
      });
    } |]
  pure func
  where
    -- NOTE(review): relies on the 'Enum' instance of 'DeviceAPI' matching the
    -- numeric values of Halide::DeviceAPI; confirm where DeviceAPI is declared.
    api :: CInt
    api = fromIntegral (fromEnum deviceApi)
-- | 'copyToDevice' specialised to 'DeviceHost'.
copyToHost :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Func t n a)
copyToHost = copyToDevice DeviceHost
-- | Allocate a new @Halide::ImageParam@.
--
-- The element type comes from @a@ and the dimensionality from @n@. The name
-- is optional. The result is owned by a 'ForeignPtr' whose finalizer deletes
-- the C++ object.
mkBufferParameter
  :: forall n a. (KnownNat n, IsHalideType a) => Maybe Text -> IO (ForeignPtr CxxImageParam)
mkBufferParameter maybeName =
  with (halideTypeFor (Proxy @a)) $ \t -> do
    let d :: CInt
        d = fromIntegral $ natVal (Proxy @n)
        createWithoutName :: IO (Ptr CxxImageParam)
        createWithoutName =
          [CU.exp| Halide::ImageParam* {
            new Halide::ImageParam{Halide::Type{*$(halide_type_t* t)}, $(int d)} } |]
        deleter :: FunPtr (Ptr CxxImageParam -> IO ())
        deleter = [C.funPtr| void deleteImageParam(Halide::ImageParam* p) { delete p; } |]
        createWithName :: Text -> IO (Ptr CxxImageParam)
        createWithName name =
          -- The name is passed to C++ as UTF-8 bytes with an explicit length.
          let s :: ByteString
              s = T.encodeUtf8 name
           in [CU.exp| Halide::ImageParam* {
                new Halide::ImageParam{
                  Halide::Type{*$(halide_type_t* t)},
                  $(int d),
                  std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}} } |]
    newForeignPtr deleter =<< maybe createWithoutName createWithName maybeName
-- | Return the @ImageParam@ cached in the 'IORef', creating and caching one
-- on first use.
--
-- The name is used only when a new parameter is created. If the 'IORef'
-- already holds one, the name argument is ignored.
getBufferParameter
  :: forall n a
   . (KnownNat n, IsHalideType a)
  => Maybe Text
  -> IORef (Maybe (ForeignPtr CxxImageParam))
  -> IO (ForeignPtr CxxImageParam)
getBufferParameter name r =
  readIORef r >>= \case
    Just fp -> pure fp
    Nothing -> do
      fp <- mkBufferParameter @n @a name
      writeIORef r (Just fp)
      pure fp
-- | Run an action with a raw pointer to the buffer parameter's
-- @Halide::ImageParam@.
--
-- The parameter is created without a name if it does not exist yet. The
-- pointer is only valid for the duration of the action.
withBufferParam
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Func 'ParamTy n a
  -> (Ptr CxxImageParam -> IO b)
  -> IO b
withBufferParam (Param r) action = do
  fp <- getBufferParameter @n @a Nothing r
  withForeignPtr fp action
-- | Run an action with a raw pointer to the underlying @Halide::Func@.
--
-- A buffer parameter is first converted to a 'Func' via 'forceFunc'. The
-- conversion is sequenced in 'IO' rather than going through the
-- 'unsafePerformIO' inside 'funcToForeignPtr', so its side effects (creating
-- and caching the @ImageParam@) happen in a well-defined order. The pointer is
-- only valid for the duration of the action.
withFunc :: (KnownNat n, IsHalideType a) => Func t n a -> (Ptr CxxFunc -> IO b) -> IO b
withFunc f action =
  forceFunc f >>= \case
    Func fp -> withForeignPtr fp action
-- | Take ownership of a heap-allocated @Halide::Func@.
--
-- The pointer is wrapped in a 'ForeignPtr' whose finalizer @delete@s it. The
-- pointer must come from C++ @new@ and must not be freed elsewhere.
wrapCxxFunc :: Ptr CxxFunc -> IO (Func 'FuncTy n a)
wrapCxxFunc = fmap Func . newForeignPtr deleter
  where
    deleter :: FunPtr (Ptr CxxFunc -> IO ())
    deleter = [C.funPtr| void deleteFunc(Halide::Func *x) { delete x; } |]
-- | View any 'Func' as a 'FuncTy' function.
--
-- A 'Func' constructor is returned unchanged. A 'Param' is converted into a
-- new @Halide::Func@ built from its @ImageParam@ (creating the parameter,
-- unnamed, if needed).
forceFunc :: forall t n a. (KnownNat n, IsHalideType a) => Func t n a -> IO (Func 'FuncTy n a)
forceFunc x@(Func _) = pure x
forceFunc (Param r) = do
  fp <- getBufferParameter @n @a Nothing r
  withForeignPtr fp $ \p ->
    wrapCxxFunc
      =<< [CU.exp| Halide::Func* {
            new Halide::Func{static_cast<Halide::Func>(*$(Halide::ImageParam* p))} } |]
-- | Pure accessor for the 'ForeignPtr' of the underlying @Halide::Func@.
--
-- Implemented with 'unsafePerformIO' over 'forceFunc'. For a 'Param' this may
-- create and cache the @ImageParam@ and allocate a new @Halide::Func@ wrapper.
-- Prefer 'withFunc' from 'IO' code.
funcToForeignPtr :: (KnownNat n, IsHalideType a) => Func t n a -> ForeignPtr CxxFunc
funcToForeignPtr x =
  -- `$!` on an IO action only forced the action value, not its result;
  -- plain `$` has the same effect here.
  unsafePerformIO $
    forceFunc x >>= \case
      Func fp -> pure fp
-- | Create a named 'Func' whose pure definition is @name(args...) = expr@.
--
-- The argument tuple is converted to a @std::vector<Halide::Var>@. The
-- definition runs inside @handle_halide_exceptions@, consistent with
-- 'update', so a Halide error raised while defining the function is rethrown
-- on the Haskell side. Previously such an error escaped from an unsafe FFI
-- call, where a C++ exception cannot be handled.
define
  :: ( IsTuple (Arguments ts) i
     , All ((~) Var) ts
     , Length ts ~ n
     , KnownNat n
     , IsHalideType a
     )
  => Text
  -> i
  -> Expr a
  -> IO (Func 'FuncTy n a)
define name args expr =
  asVectorOf @((~) (Expr Int32)) asVar (fromTuple args) $ \x -> do
    let s :: ByteString
        s = T.encodeUtf8 name
    asExpr expr $ \y ->
      wrapCxxFunc
        =<< [C.throwBlock| Halide::Func* {
              return handle_halide_exceptions([=](){
                Halide::Func f{std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}};
                f(*$(std::vector<Halide::Var>* x)) = *$(Halide::Expr* y);
                return new Halide::Func{f};
              });
            } |]
-- | Add an update definition @func(args...) = expr@ to an existing 'Func'.
--
-- The index expressions are converted to a @std::vector<Halide::Expr>@.
-- Halide errors are rethrown on the Haskell side.
update
  :: ( IsTuple (Arguments ts) i
     , All ((~) (Expr Int32)) ts
     , Length ts ~ n
     , KnownNat n
     , IsHalideType a
     )
  => Func 'FuncTy n a
  -> i
  -> Expr a
  -> IO ()
update func args expr =
  withFunc func $ \f ->
    asVectorOf @((~) (Expr Int32)) asExpr (fromTuple args) $ \x ->
      asExpr expr $ \y ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Func* f)->operator()(*$(std::vector<Halide::Expr>* x)) = *$(Halide::Expr* y);
          });
        } |]
infix 9 !

-- | Build the expression @func(args...)@, i.e. a call of the function at the
-- given index expressions.
--
-- Pure interface over FFI calls via 'unsafePerformIO'. Each evaluation
-- constructs a fresh @Halide::Expr@ from @Func::operator()@.
(!)
  :: ( IsTuple (Arguments ts) i
     , All ((~) (Expr Int32)) ts
     , Length ts ~ n
     , KnownNat n
     , IsHalideType a
     )
  => Func t n a
  -> i
  -> Expr a
(!) func args =
  unsafePerformIO $
    withFunc func $ \f ->
      asVectorOf @((~) (Expr Int32)) asExpr (fromTuple args) $ \x ->
        cxxConstructExpr $ \ptr ->
          [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
            $(Halide::Func* f)->operator()(*$(std::vector<Halide::Expr>* x))} } |]
-- | Get the @k@-th dimension (0-based) of a buffer parameter as a
-- 'Dimension' (Halide's @ImageParam::dim@).
--
-- Calls 'error' when @k@ is outside @[0, n)@.
dim
  :: forall n a
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Int
  -> Func 'ParamTy n a
  -> IO Dimension
dim k func
  | 0 <= k && k < rank =
      let k' :: CInt
          k' = fromIntegral k
       in withBufferParam func $ \f ->
            wrapCxxDimension
              =<< [CU.exp| Halide::Internal::Dimension* {
                    new Halide::Internal::Dimension{$(Halide::ImageParam* f)->dim($(int k'))} } |]
  | otherwise =
      error $
        "invalid dimension index: "
          <> show k
          <> "; Func is "
          <> show (natVal (Proxy @n))
          <> "-dimensional"
  where
    rank :: Int
    rank = fromIntegral (natVal (Proxy @n))
-- | Render the loop nest of a function as text.
--
-- Uses @Halide::Internal::print_loop_nest@. Halide errors are rethrown on the
-- Haskell side.
prettyLoopNest :: (KnownNat n, IsHalideType r) => Func t n r -> IO Text
prettyLoopNest func =
  withFunc func $ \f ->
    peekAndDeleteCxxString
      =<< [C.throwBlock| std::string* {
            return handle_halide_exceptions([=]() {
              return new std::string{Halide::Internal::print_loop_nest(
                std::vector<Halide::Internal::Function>{$(Halide::Func* f)->function()})};
            });
          } |]
-- | Evaluate the function over a CPU buffer of the given shape and run an
-- action on the filled buffer.
--
-- The buffer comes from 'allocaCpuBuffer'. NOTE(review): presumably it is only
-- valid inside the action; confirm against allocaCpuBuffer. Halide errors
-- during @realize@ are rethrown on the Haskell side.
realize
  :: forall n a t b
   . (KnownNat n, IsHalideType a)
  => Func t n a
  -> [Int]
  -> (Ptr (HalideBuffer n a) -> IO b)
  -> IO b
realize func shape action =
  withFunc func $ \f ->
    allocaCpuBuffer shape $ \buf -> do
      let raw :: Ptr RawHalideBuffer
          raw = castPtr buf
      [C.throwBlock| void {
        handle_halide_exceptions([=](){
          $(Halide::Func* f)->realize(
            Halide::Pipeline::RealizationArg{$(halide_buffer_t* raw)});
        });
      } |]
      action buf
-- | Give a name to a buffer parameter.
--
-- The underlying @ImageParam@ is created with this name when the result is
-- evaluated. If the parameter has already been created (e.g. by earlier use),
-- the name is silently ignored; see 'getBufferParameter'.
--
-- Uses 'unsafePerformIO' to mutate the parameter's 'IORef'. NOINLINE keeps
-- GHC from duplicating or floating that side effect.
{-# NOINLINE buffer #-}
buffer :: forall n a. (KnownNat n, IsHalideType a) => Text -> Func 'ParamTy n a -> Func 'ParamTy n a
buffer name p@(Param r) = unsafePerformIO $ do
  _ <- getBufferParameter @n @a (Just name) r
  pure p
-- | Give a name to a scalar parameter expression.
--
-- Creates the underlying parameter with the given name when the result is
-- evaluated. Calls 'error' if the parameter already exists, or if the
-- expression is not a 'ScalarParam'.
--
-- Uses 'unsafePerformIO' to mutate the parameter's 'IORef'. NOINLINE keeps
-- GHC from duplicating the call; a duplicate would hit the
-- "already been set" error.
{-# NOINLINE scalar #-}
scalar :: forall a. IsHalideType a => Text -> Expr a -> Expr a
scalar name (ScalarParam r) = unsafePerformIO $ do
  readIORef r >>= \case
    Just _ -> error "the name of this Expr has already been set"
    Nothing -> do
      fp <- mkScalarParameter @a (Just name)
      writeIORef r (Just fp)
  pure (ScalarParam r)
scalar _ _ = error "cannot set the name of an expression that is not a parameter"
-- | Take ownership of a heap-allocated @Halide::Stage@.
--
-- The pointer is wrapped in a 'ForeignPtr' whose finalizer @delete@s it.
wrapCxxStage :: (KnownNat n, IsHalideType a) => Ptr CxxStage -> IO (Stage n a)
wrapCxxStage = fmap Stage . newForeignPtr deleter
  where
    deleter :: FunPtr (Ptr CxxStage -> IO ())
    deleter = [C.funPtr| void deleteStage(Halide::Stage* p) { delete p; } |]
-- | Run an action with a raw pointer to the underlying @Halide::Stage@.
--
-- The pointer is valid only for the duration of the action.
withCxxStage :: (KnownNat n, IsHalideType a) => Stage n a -> (Ptr CxxStage -> IO b) -> IO b
withCxxStage (Stage fp) = withForeignPtr fp
-- | Get the pure-definition stage of a function.
--
-- Built by converting the @Halide::Func@ to a @Halide::Stage@ via
-- @static_cast@.
getStage :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Stage n a)
getStage func =
  withFunc func $ \func' ->
    [CU.exp| Halide::Stage* { new Halide::Stage{static_cast<Halide::Stage>(*$(Halide::Func* func'))} } |]
      >>= wrapCxxStage
-- | Return whether the function has any update definitions (Halide's
-- @Func::has_update_definition@).
hasUpdateDefinitions :: (KnownNat n, IsHalideType a) => Func t n a -> IO Bool
hasUpdateDefinitions func =
  withFunc func $ \func' ->
    toBool <$> [CU.exp| bool { $(const Halide::Func* func')->has_update_definition() } |]
-- | Get the @k@-th (0-based) update stage of a function (Halide's
-- @Func::update(k)@).
--
-- Halide reports an error when @k@ is out of range. The call now runs inside
-- @handle_halide_exceptions@, like the other fallible calls in this module, so
-- that error is rethrown on the Haskell side. Previously it escaped from an
-- unsafe FFI call.
getUpdateStage :: (KnownNat n, IsHalideType a) => Int -> Func 'FuncTy n a -> IO (Stage n a)
getUpdateStage k func =
  withFunc func $ \func' ->
    let k' :: CInt
        k' = fromIntegral k
     in [C.throwBlock| Halide::Stage* {
          return handle_halide_exceptions([=](){
            return new Halide::Stage{$(Halide::Func* func')->update($(int k'))};
          });
        } |]
          >>= wrapCxxStage
-- | Build the loop level of variable @var@ in the given stage of @func@.
--
-- Constructs @Halide::LoopLevel{func, var, stageIndex}@; Halide errors are
-- rethrown on the Haskell side. Calls 'error' if the resulting level is not
-- a locked 'LoopLevel'.
getLoopLevelAtStage
  :: (KnownNat n, IsHalideType a)
  => Func t n a
  -> Expr Int32
  -> Int
  -> IO (LoopLevel 'LockedTy)
getLoopLevelAtStage func var stageIndex =
  withFunc func $ \f -> asVarOrRVar var $ \i -> do
    SomeLoopLevel level <-
      wrapCxxLoopLevel
        =<< [C.throwBlock| Halide::LoopLevel* {
              return handle_halide_exceptions([=](){
                return new Halide::LoopLevel{*$(const Halide::Func* f),
                                             *$(const Halide::VarOrRVar* i),
                                             $(int k)};
              });
            } |]
    -- Matching the 'LoopLevel' constructor refines the existential index to
    -- 'LockedTy; any other shape is reported as an error.
    case level of
      LoopLevel _ -> pure level
      _ ->
        error $
          "getLoopLevelAtStage: got "
            <> show level
            <> ", but expected a LoopLevel 'LockedTy"
  where
    k :: CInt
    k = fromIntegral stageIndex
-- | 'getLoopLevelAtStage' with stage index @-1@.
--
-- NOTE(review): @-1@ presumably selects the last stage in Halide; confirm.
getLoopLevel :: (KnownNat n, IsHalideType a) => Func t n a -> Expr Int32 -> IO (LoopLevel 'LockedTy)
getLoopLevel f i = getLoopLevelAtStage f i (-1)
-- | Call Halide's @Func::store_at@ with the given loop level.
--
-- The same 'Func' is returned for chaining.
storeAt :: (KnownNat n, IsHalideType a) => Func 'FuncTy n a -> LoopLevel t -> IO (Func 'FuncTy n a)
storeAt func level = do
  withFunc func $ \f ->
    withCxxLoopLevel level $ \l ->
      [CU.exp| void { $(Halide::Func* f)->store_at(*$(const Halide::LoopLevel* l)) } |]
  pure func
-- | Call Halide's @Func::compute_at@ with the given loop level.
--
-- The same 'Func' is returned for chaining.
computeAt :: (KnownNat n, IsHalideType a) => Func 'FuncTy n a -> LoopLevel t -> IO (Func 'FuncTy n a)
computeAt func level = do
  withFunc func $ \f ->
    withCxxLoopLevel level $ \l ->
      [CU.exp| void { $(Halide::Func* f)->compute_at(*$(const Halide::LoopLevel* l)) } |]
  pure func
-- | Expose a Haskell-side buffer to Halide as an unnamed buffer parameter.
--
-- Creates a fresh @ImageParam@, binds it to the buffer via
-- @ImageParam::set@, and passes the resulting 'Param' to the action.
-- NOTE(review): the bound @Halide::Buffer@ wraps the memory behind the
-- 'withHalideBuffer' pointer, so the 'Func' should presumably not be used to
-- read that data after the action returns; confirm.
asBufferParam
  :: forall n a t b
   . IsHalideBuffer t n a
  => t
  -> (Func 'ParamTy n a -> IO b)
  -> IO b
asBufferParam arr action =
  withHalideBuffer @n @a arr $ \arr' -> do
    param <- mkBufferParameter @n @a Nothing
    withForeignPtr param $ \param' ->
      let buf :: Ptr RawHalideBuffer
          buf = castPtr arr'
       in [CU.block| void {
            $(Halide::ImageParam* param')->set(Halide::Buffer<>{*$(const halide_buffer_t* buf)});
          } |]
    action . Param =<< newIORef (Just param)