{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} -- {-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -Wno-redundant-constraints #-} -- | -- Module : Language.Halide.Func -- Description : Functions / Arrays -- Copyright : (c) Tom Westerhout, 2023 module Language.Halide.Func ( -- * Defining pipelines Func (..) , FuncTy (..) , Stage (..) , buffer , scalar , define , (!) , realize -- * Scheduling , Schedulable (..) , TailStrategy (..) -- ** 'Func'-specific , computeRoot , getStage , getLoopLevel , getLoopLevelAtStage , asUsed , asUsedBy , copyToDevice , copyToHost , storeAt , computeAt , dim , estimate , bound , getArgs -- , deepCopy -- * Update definitions , update , hasUpdateDefinitions , getUpdateStage -- * Debugging , prettyLoopNest -- * Internal , IndexTuple , asBufferParam , withFunc , withBufferParam , wrapCxxFunc , CxxStage , wrapCxxStage , withCxxStage ) where import Control.Exception (bracket) import Control.Monad (forM) import Data.IORef import Data.Kind (Type) import Data.Proxy import Data.Text (Text) import Data.Text.Encoding qualified as T import Foreign.ForeignPtr import Foreign.Marshal (toBool, with) import Foreign.Ptr (Ptr, castPtr) import GHC.Stack (HasCallStack) import GHC.TypeLits import Language.C.Inline qualified as C import Language.C.Inline.Cpp.Exception qualified as C import Language.C.Inline.Unsafe qualified as CU import Language.Halide.Buffer import Language.Halide.Context import Language.Halide.Dimension import Language.Halide.Expr import Language.Halide.LoopLevel import Language.Halide.Target import Language.Halide.Type import Language.Halide.Utils import System.IO.Unsafe (unsafePerformIO) import Prelude hiding (min, tail) -- | Haskell counterpart of [Halide::Stage](https://halide-lang.org/docs/class_halide_1_1_stage.html). data CxxStage importHalide -- | A function in Halide. Conceptually, it can be thought of as a lazy -- @n@-dimensional buffer of type @a@. -- -- This is a wrapper around the [@Halide::Func@](https://halide-lang.org/docs/class_halide_1_1_func.html) -- C++ type. data Func (t :: FuncTy) (n :: Nat) (a :: Type) where Func :: {-# UNPACK #-} !(ForeignPtr CxxFunc) -> Func 'FuncTy n a Param :: {-# UNPACK #-} !(IORef (Maybe (ForeignPtr CxxImageParam))) -> Func 'ParamTy n a -- | Function type. It can either be 'FuncTy' which means that we have defined the function ourselves, -- or 'ParamTy' which means that it's a parameter to our pipeline. data FuncTy = FuncTy | ParamTy deriving stock (Show, Eq, Ord) -- | A single definition of a t'Func'. newtype Stage (n :: Nat) (a :: Type) = Stage (ForeignPtr CxxStage) -- | Different ways to handle a tail case in a split when the split factor does -- not provably divide the extent. -- -- This is the Haskell counterpart of [@Halide::TailStrategy@](https://halide-lang.org/docs/namespace_halide.html#a6c6557df562bd7850664e70fdb8fea0f). data TailStrategy = -- | Round up the extent to be a multiple of the split factor. -- -- Not legal for RVars, as it would change the meaning of the algorithm. -- -- * Pros: generates the simplest, fastest code. -- * Cons: if used on a stage that reads from the input or writes to the -- output, constrains the input or output size to be a multiple of the -- split factor. TailRoundUp | -- | Guard the inner loop with an if statement that prevents evaluation -- beyond the original extent. -- -- Always legal. The if statement is treated like a boundary condition, and -- factored out into a loop epilogue if possible. -- -- * Pros: no redundant re-evaluation; does not constrain input our output sizes. -- * Cons: increases code size due to separate tail-case handling; -- vectorization will scalarize in the tail case to handle the if -- statement. TailGuardWithIf | -- | Guard the loads and stores in the loop with an if statement that -- prevents evaluation beyond the original extent. -- -- Always legal. The if statement is treated like a boundary condition, and -- factored out into a loop epilogue if possible. -- * Pros: no redundant re-evaluation; does not constrain input or output -- sizes. -- * Cons: increases code size due to separate tail-case handling. TailPredicate | -- | Guard the loads in the loop with an if statement that prevents -- evaluation beyond the original extent. -- -- Only legal for innermost splits. Not legal for RVars, as it would change -- the meaning of the algorithm. The if statement is treated like a -- boundary condition, and factored out into a loop epilogue if possible. -- * Pros: does not constrain input sizes, output size constraints are -- simpler than full predication. -- * Cons: increases code size due to separate tail-case handling, -- constrains the output size to be a multiple of the split factor. TailPredicateLoads | -- | Guard the stores in the loop with an if statement that prevents -- evaluation beyond the original extent. -- -- Only legal for innermost splits. Not legal for RVars, as it would change -- the meaning of the algorithm. The if statement is treated like a -- boundary condition, and factored out into a loop epilogue if possible. -- * Pros: does not constrain output sizes, input size constraints are -- simpler than full predication. -- * Cons: increases code size due to separate tail-case handling, -- constraints the input size to be a multiple of the split factor. TailPredicateStores | -- | Prevent evaluation beyond the original extent by shifting the tail -- case inwards, re-evaluating some points near the end. -- -- Only legal for pure variables in pure definitions. If the inner loop is -- very simple, the tail case is treated like a boundary condition and -- factored out into an epilogue. -- -- This is a good trade-off between several factors. Like 'TailRoundUp', it -- supports vectorization well, because the inner loop is always a fixed -- size with no data-dependent branching. It increases code size slightly -- for inner loops due to the epilogue handling, but not for outer loops -- (e.g. loops over tiles). If used on a stage that reads from an input or -- writes to an output, this stategy only requires that the input/output -- extent be at least the split factor, instead of a multiple of the split -- factor as with 'TailRoundUp'. TailShiftInwards | -- | For pure definitions use 'TailShiftInwards'. -- -- For pure vars in update definitions use 'TailRoundUp'. For RVars in update -- definitions use 'TailGuardWithIf'. TailAuto deriving stock (Eq, Ord, Show) -- | Specifies that @i@ is a tuple of @'Expr' Int32@. -- -- @ts@ are deduced from @i@, so you don't have to specify them explicitly. type IndexTuple i ts = (IsTuple (Arguments ts) i, All ((~) (Expr Int32)) ts) -- | Common scheduling functions class (KnownNat n, IsHalideType a) => Schedulable f n a where -- | Vectorize the dimension. vectorize :: VarOrRVar -> f n a -> IO (f n a) -- | Unroll the dimension. unroll :: VarOrRVar -> f n a -> IO (f n a) -- | Reorder variables to have the given nesting order, from innermost out. reorder :: [VarOrRVar] -> f n a -> IO (f n a) -- | Split a dimension into inner and outer subdimensions with the given names, where the inner dimension -- iterates from @0@ to @factor-1@. -- -- The inner and outer subdimensions can then be dealt with using the other scheduling calls. It's okay -- to reuse the old variable name as either the inner or outer variable. The first argument specifies -- how the tail should be handled if the split factor does not provably divide the extent. split :: TailStrategy -> VarOrRVar -> (VarOrRVar, VarOrRVar) -> Expr Int32 -> f n a -> IO (f n a) -- | Join two dimensions into a single fused dimenion. -- -- The fused dimension covers the product of the extents of the inner and outer dimensions given. fuse :: (VarOrRVar, VarOrRVar) -> VarOrRVar -> f n a -> IO (f n a) -- | Mark the dimension to be traversed serially serial :: VarOrRVar -> f n a -> IO (f n a) -- | Mark the dimension to be traversed in parallel parallel :: VarOrRVar -> f n a -> IO (f n a) specialize :: Expr Bool -> f n a -> IO (Stage n a) specializeFail :: Text -> f n a -> IO () gpuBlocks :: (IndexTuple i ts, 1 <= Length ts, Length ts <= 3) => DeviceAPI -> i -> f n a -> IO (f n a) gpuThreads :: (IndexTuple i ts, 1 <= Length ts, Length ts <= 3) => DeviceAPI -> i -> f n a -> IO (f n a) gpuLanes :: DeviceAPI -> VarOrRVar -> f n a -> IO (f n a) -- | Schedule the iteration over this stage to be fused with another stage from outermost loop to a -- given LoopLevel. -- -- For more info, see [Halide::Stage::compute_with](https://halide-lang.org/docs/class_halide_1_1_stage.html#a82a2ae25a009d6a2d52cb407a25f0a5b). computeWith :: LoopAlignStrategy -> f n a -> LoopLevel t -> IO () instance (KnownNat n, IsHalideType a) => Schedulable Stage n a where vectorize var stage = do withCxxStage stage $ \stage' -> asVarOrRVar var $ \var' -> [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Stage* stage')->vectorize(*$(const Halide::VarOrRVar* var')); }); } |] pure stage unroll var stage = do withCxxStage stage $ \stage' -> asVarOrRVar var $ \var' -> [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Stage* stage')->unroll(*$(const Halide::VarOrRVar* var')); }); } |] pure stage reorder args stage = do withMany asVarOrRVar args $ \args' -> do withCxxStage stage $ \stage' -> [C.throwBlock| void { handle_halide_exceptions([=]() { $(Halide::Stage* stage')->reorder( *$(const std::vector* args')); }); } |] pure stage split tail old (outer, inner) factor stage = do withCxxStage stage $ \stage' -> asVarOrRVar old $ \old' -> asVarOrRVar outer $ \outer' -> asVarOrRVar inner $ \inner' -> asExpr factor $ \factor' -> [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Stage* stage')->split( *$(const Halide::VarOrRVar* old'), *$(const Halide::VarOrRVar* outer'), *$(const Halide::VarOrRVar* inner'), *$(const Halide::Expr* factor'), static_cast($(int t))); }); } |] pure stage where t = fromIntegral . fromEnum $ tail fuse (outer, inner) fused stage = do withCxxStage stage $ \stage' -> asVarOrRVar outer $ \outer' -> asVarOrRVar inner $ \inner' -> asVarOrRVar fused $ \fused' -> [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Stage* stage')->fuse( *$(const Halide::VarOrRVar* outer'), *$(const Halide::VarOrRVar* inner'), *$(const Halide::VarOrRVar* fused')); }); } |] pure stage serial var stage = do withCxxStage stage $ \stage' -> asVarOrRVar var $ \var' -> [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Stage* stage')->serial(*$(const Halide::VarOrRVar* var')); }); } |] pure stage parallel var stage = do withCxxStage stage $ \stage' -> asVarOrRVar var $ \var' -> [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Stage* stage')->parallel(*$(const Halide::VarOrRVar* var')); }); } |] pure stage specialize cond stage = do withCxxStage stage $ \stage' -> asExpr cond $ \cond' -> wrapCxxStage =<< [C.throwBlock| Halide::Stage* { return handle_halide_exceptions([=](){ return new Halide::Stage{$(Halide::Stage* stage')->specialize( *$(const Halide::Expr* cond'))}; }); } |] specializeFail (T.encodeUtf8 -> s) stage = withCxxStage stage $ \stage' -> [C.throwBlock| void { return handle_halide_exceptions([=](){ $(Halide::Stage* stage')->specialize_fail( std::string{$bs-ptr:s, static_cast($bs-len:s)}); }); } |] gpuBlocks (fromIntegral . fromEnum -> api) vars stage = do withCxxStage stage $ \stage' -> asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple vars) $ \vars' -> do [C.throwBlock| void { handle_halide_exceptions([=](){ auto const& vars = *$(const std::vector* vars'); auto& stage = *$(Halide::Stage* stage'); auto const device = static_cast($(int api)); switch (vars.size()) { case 1: stage.gpu_blocks(vars.at(0), device); break; case 2: stage.gpu_blocks(vars.at(0), vars.at(1), device); break; case 3: stage.gpu_blocks(vars.at(0), vars.at(1), vars.at(2), device); break; default: throw std::runtime_error{"unexpected number of arguments in gpuBlocks"}; } }); } |] pure stage gpuThreads (fromIntegral . fromEnum -> api) vars stage = do withCxxStage stage $ \stage' -> asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple vars) $ \vars' -> do [C.throwBlock| void { handle_halide_exceptions([=](){ auto const& vars = *$(const std::vector* vars'); auto& stage = *$(Halide::Stage* stage'); auto const device = static_cast($(int api)); switch (vars.size()) { case 1: stage.gpu_threads(vars.at(0), device); break; case 2: stage.gpu_threads(vars.at(0), vars.at(1), device); break; case 3: stage.gpu_threads(vars.at(0), vars.at(1), vars.at(2), device); break; default: throw std::runtime_error{"unexpected number of arguments in gpuThreads"}; } }); } |] pure stage gpuLanes (fromIntegral . fromEnum -> api) var stage = do withCxxStage stage $ \stage' -> asVarOrRVar var $ \var' -> [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Stage* stage')->gpu_lanes( *$(const Halide::VarOrRVar* var'), static_cast($(int api))); }); } |] pure stage computeWith (fromIntegral . fromEnum -> align) stage level = do withCxxStage stage $ \stage' -> withCxxLoopLevel level $ \level' -> [C.throwBlock| void { handle_halide_exceptions([=]() { $(Halide::Stage* stage')->compute_with( *$(const Halide::LoopLevel* level'), static_cast($(int align))); }); } |] viaStage1 :: (KnownNat n, IsHalideType b) => (a -> Stage n b -> IO (Stage n b)) -> a -> Func t n b -> IO (Func t n b) viaStage1 f a1 func = do _ <- f a1 =<< getStage func pure func viaStage2 :: (KnownNat n, IsHalideType b) => (a1 -> a2 -> Stage n b -> IO (Stage n b)) -> a1 -> a2 -> Func t n b -> IO (Func t n b) viaStage2 f a1 a2 func = do _ <- f a1 a2 =<< getStage func pure func {- viaStage3 :: (KnownNat n, IsHalideType b) => (a1 -> a2 -> a3 -> Stage n b -> IO (Stage n b)) -> a1 -> a2 -> a3 -> Func t n b -> IO (Func t n b) viaStage3 f a1 a2 a3 func = do _ <- f a1 a2 a3 =<< getStage func pure func -} viaStage4 :: (KnownNat n, IsHalideType b) => (a1 -> a2 -> a3 -> a4 -> Stage n b -> IO (Stage n b)) -> a1 -> a2 -> a3 -> a4 -> Func t n b -> IO (Func t n b) viaStage4 f a1 a2 a3 a4 func = do _ <- f a1 a2 a3 a4 =<< getStage func pure func instance (KnownNat n, IsHalideType a) => Schedulable (Func t) n a where vectorize = viaStage1 vectorize unroll = viaStage1 unroll reorder = viaStage1 reorder split = viaStage4 split fuse = viaStage2 fuse serial = viaStage1 serial parallel = viaStage1 parallel specialize cond func = getStage func >>= specialize cond specializeFail msg func = getStage func >>= specializeFail msg gpuBlocks = viaStage2 gpuBlocks gpuThreads = viaStage2 gpuThreads gpuLanes = viaStage2 gpuLanes computeWith a f l = getStage f >>= \f' -> computeWith a f' l instance Enum TailStrategy where fromEnum = fromIntegral . \case TailRoundUp -> [CU.pure| int { static_cast(Halide::TailStrategy::RoundUp) } |] TailGuardWithIf -> [CU.pure| int { static_cast(Halide::TailStrategy::GuardWithIf) } |] TailPredicate -> [CU.pure| int { static_cast(Halide::TailStrategy::Predicate) } |] TailPredicateLoads -> [CU.pure| int { static_cast(Halide::TailStrategy::PredicateLoads) } |] TailPredicateStores -> [CU.pure| int { static_cast(Halide::TailStrategy::PredicateStores) } |] TailShiftInwards -> [CU.pure| int { static_cast(Halide::TailStrategy::ShiftInwards) } |] TailAuto -> [CU.pure| int { static_cast(Halide::TailStrategy::Auto) } |] toEnum k | fromIntegral k == [CU.pure| int { static_cast(Halide::TailStrategy::RoundUp) } |] = TailRoundUp | fromIntegral k == [CU.pure| int { static_cast(Halide::TailStrategy::GuardWithIf) } |] = TailGuardWithIf | fromIntegral k == [CU.pure| int { static_cast(Halide::TailStrategy::Predicate) } |] = TailPredicate | fromIntegral k == [CU.pure| int { static_cast(Halide::TailStrategy::PredicateLoads) } |] = TailPredicateLoads | fromIntegral k == [CU.pure| int { static_cast(Halide::TailStrategy::PredicateStores) } |] = TailPredicateStores | fromIntegral k == [CU.pure| int { static_cast(Halide::TailStrategy::ShiftInwards) } |] = TailShiftInwards | fromIntegral k == [CU.pure| int { static_cast(Halide::TailStrategy::Auto) } |] = TailAuto | otherwise = error $ "invalid TailStrategy: " <> show k -- vectorize -- :: (KnownNat n, IsHalideType a) -- => TailStrategy -- -> Func t n a -- -> Expr Int32 -- -- ^ Variable to vectorize -- -> Expr Int32 -- -- ^ Split factor -- -> IO () -- vectorize strategy func var factor = -- withFunc func $ \f -> -- asVarOrRVar var $ \x -> -- asExpr factor $ \n -> -- [C.throwBlock| void { -- $(Halide::Func* f)->vectorize(*$(Halide::VarOrRVar* x), *$(Halide::Expr* n), -- static_cast($(int tail))); -- } |] -- where -- tail = fromIntegral (fromEnum strategy) -- | Split a dimension by the given factor, then unroll the inner dimension. -- -- This is how you unroll a loop of unknown size by some constant factor. After -- this call, @var@ refers to the outer dimension of the split. -- unroll -- :: (KnownNat n, IsHalideType a) -- => TailStrategy -- -> Func t n a -- -> Expr Int32 -- -- ^ Variable @var@ to vectorize -- -> Expr Int32 -- -- ^ Split factor -- -> IO () -- unroll strategy func var factor = -- withFunc func $ \f -> -- asVarOrRVar var $ \x -> -- asExpr factor $ \n -> -- [C.throwBlock| void { -- $(Halide::Func* f)->unroll(*$(Halide::VarOrRVar* x), *$(Halide::Expr* n), -- static_cast($(int tail))); -- } |] -- where -- tail = fromIntegral (fromEnum strategy) -- | Reorder variables to have the given nesting order, from innermost out. -- reorder -- :: forall t n a i ts -- . ( IsTuple (Arguments ts) i -- , All ((~) (Expr Int32)) ts -- , Length ts ~ n -- , KnownNat n -- , IsHalideType a -- ) -- => Func t n a -- -> i -- -> IO () -- reorder func args = -- asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple args) $ \v -> do -- withFunc func $ \f -> -- [C.throwBlock| void { $(Halide::Func* f)->reorder(*$(std::vector* v)); } |] -- | Statically declare the range over which the function will be evaluated in the general case. -- -- This provides a basis for the auto scheduler to make trade-offs and scheduling decisions. -- The auto generated schedules might break when the sizes of the dimensions are very different from the -- estimates specified. These estimates are used only by the auto scheduler if the function is a pipeline output. estimate :: (KnownNat n, IsHalideType a) => Expr Int32 -- ^ index variable -> Expr Int32 -- ^ @min@ estimate -> Expr Int32 -- ^ @extent@ estimate -> Func t n a -> IO () estimate var min extent func = withFunc func $ \f -> asVar var $ \i -> asExpr min $ \minExpr -> asExpr extent $ \extentExpr -> [CU.exp| void { $(Halide::Func* f)->set_estimate( *$(Halide::Var* i), *$(Halide::Expr* minExpr), *$(Halide::Expr* extentExpr)) } |] -- | Statically declare the range over which a function should be evaluated. -- -- This can let Halide perform some optimizations. E.g. if you know there are going to be 4 color channels, -- you can completely vectorize the color channel dimension without the overhead of splitting it up. -- If bounds inference decides that it requires more of this function than the bounds you have stated, -- a runtime error will occur when you try to run your pipeline. bound :: (KnownNat n, IsHalideType a) => Expr Int32 -- ^ index variable -> Expr Int32 -- ^ @min@ estimate -> Expr Int32 -- ^ @extent@ estimate -> Func t n a -> IO () bound var min extent func = withFunc func $ \f -> asVar var $ \i -> asExpr min $ \minExpr -> asExpr extent $ \extentExpr -> [CU.exp| void { $(Halide::Func* f)->bound( *$(Halide::Var* i), *$(Halide::Expr* minExpr), *$(Halide::Expr* extentExpr)) } |] -- | Get the index arguments of the function. -- -- The returned list contains exactly @n@ elements. getArgs :: (KnownNat n, IsHalideType a) => Func t n a -> IO [Var] getArgs func = withFunc func $ \func' -> do let allocate = [CU.exp| std::vector* { new std::vector{$(const Halide::Func* func')->args()} } |] destroy v = [CU.exp| void { delete $(std::vector* v) } |] bracket allocate destroy $ \v -> do n <- [CU.exp| size_t { $(const std::vector* v)->size() } |] forM [0 .. n - 1] $ \i -> fmap Var . cxxConstruct $ \ptr -> [CU.exp| void { new ($(Halide::Var* ptr)) Halide::Var{$(const std::vector* v)->at($(size_t i))} } |] -- | Compute all of this function once ahead of time. -- -- See [Halide::Func::compute_root](https://halide-lang.org/docs/class_halide_1_1_func.html#a29df45a4a16a63eb81407261a9783060) for more info. computeRoot :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Func t n a) computeRoot func = do withFunc func $ \f -> [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Func* f)->compute_root(); }); } |] pure func -- | Creates and returns a new identity Func that wraps this Func. -- -- During compilation, Halide replaces all calls to this Func done by 'f' with calls to the wrapper. -- If this Func is already wrapped for use in 'f', will return the existing wrapper. -- -- For more info, see [Halide::Func::in](https://halide-lang.org/docs/class_halide_1_1_func.html#a9d619f2d0111ea5bf640781d1324d050). asUsedBy :: (KnownNat n, KnownNat m, IsHalideType a, IsHalideType b) => Func t1 n a -> Func 'FuncTy m b -> IO (Func 'FuncTy n a) asUsedBy g f = withFunc g $ \gPtr -> withFunc f $ \fPtr -> wrapCxxFunc =<< [CU.exp| Halide::Func* { new Halide::Func{$(Halide::Func* gPtr)->in(*$(Halide::Func* fPtr))} } |] -- | Create and return a global identity wrapper, which wraps all calls to this Func by any other Func. -- -- If a global wrapper already exists, returns it. The global identity wrapper is only used by callers -- for which no custom wrapper has been specified. asUsed :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Func 'FuncTy n a) asUsed f = withFunc f $ \fPtr -> wrapCxxFunc =<< [CU.exp| Halide::Func* { new Halide::Func{$(Halide::Func* fPtr)->in()} } |] -- | Declare that this function should be implemented by a call to @halide_buffer_copy@ with the given -- target device API. -- -- Asserts that the @Func@ has a pure definition which is a simple call to a single input, and no update -- definitions. The wrapper @Func@s returned by 'asUsed' are suitable candidates. Consumes all pure variables, -- and rewrites the @Func@ to have an extern definition that calls @halide_buffer_copy@. copyToDevice :: (KnownNat n, IsHalideType a) => DeviceAPI -> Func t n a -> IO (Func t n a) copyToDevice deviceApi func = do withFunc func $ \f -> [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Func* f)->copy_to_device(static_cast($(int api))); }); } |] pure func where api = fromIntegral . fromEnum $ deviceApi -- | Same as @'copyToDevice' 'DeviceHost'@ copyToHost :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Func t n a) copyToHost = copyToDevice DeviceHost -- | Split a dimension into inner and outer subdimensions with the given names, where the inner dimension -- iterates from @0@ to @factor-1@. -- -- The inner and outer subdimensions can then be dealt with using the other scheduling calls. It's okay -- to reuse the old variable name as either the inner or outer variable. The first argument specifies -- how the tail should be handled if the split factor does not provably divide the extent. -- split -- :: (KnownNat n, IsHalideType a) -- => TailStrategy -- -- ^ how to treat the remainder -- -> Func t n a -- -> Expr Int32 -- -- ^ loop variable to split -- -> Expr Int32 -- -- ^ new outer loop variable -- -> Expr Int32 -- -- ^ new inner loop variable -- -> Expr Int32 -- -- ^ split factor -- -> IO (Func t n a) -- split tail func old outer inner factor = do -- withFunc func $ \f -> -- asVarOrRVar old $ \old' -> -- asVarOrRVar outer $ \outer' -> -- asVarOrRVar inner $ \inner' -> -- asExpr factor $ \factor' -> -- [C.throwBlock| void { -- handle_halide_exceptions([=](){ -- $(Halide::Func* f)->split( -- *$(const Halide::VarOrRVar* old'), -- *$(const Halide::VarOrRVar* outer'), -- *$(const Halide::VarOrRVar* inner'), -- *$(const Halide::Expr* factor'), -- static_cast($(int t))); -- }); } |] -- pure func -- where -- t = fromIntegral . fromEnum $ tail -- | Join two dimensions into a single fused dimenion. -- -- The fused dimension covers the product of the extents of the inner and outer dimensions given. -- fuse -- :: (KnownNat n, IsHalideType a) -- => Func t n a -- -> Expr Int32 -- -- ^ inner loop variable -- -> Expr Int32 -- -- ^ outer loop variable -- -> Expr Int32 -- -- ^ new fused loop variable -- -> IO (Func t n a) -- fuse func outer inner fused = do -- withFunc func $ \f -> -- asVarOrRVar outer $ \outer' -> -- asVarOrRVar inner $ \inner' -> -- asVarOrRVar fused $ \fused' -> -- [CU.exp| void { -- $(Halide::Func* f)->fuse( -- *$(const Halide::VarOrRVar* outer'), -- *$(const Halide::VarOrRVar* inner'), -- *$(const Halide::VarOrRVar* fused')) } |] -- pure func -- withVarOrRVarMany :: [Expr Int32] -> (Int -> Ptr (CxxVector CxxVarOrRVar) -> IO a) -> IO a -- withVarOrRVarMany xs f = -- bracket allocate destroy $ \v -> do -- let go !k [] = f k v -- go !k (y : ys) = withVarOrRVarMany y $ \p -> do -- [CU.exp| void { $(std::vector* v)->push_back(*$(Halide::VarOrRVar* p)) } |] -- go (k + 1) ys -- go 0 xs -- where -- count = fromIntegral (length xs) -- withFunc func $ \f -> -- withVarOrRVarMany vars $ \count v -> do -- unless natVal (Proxy @n) -- handleHalideExceptionsM -- [C.tryBlock| void { -- $(Halide::Func* f)->reorder(*$(std::vector* v)); -- } |] -- -- class Curry (args :: [Type]) (r :: Type) (f :: Type) | args r -> f where -- curryG :: (Arguments args -> r) -> f mkBufferParameter :: forall n a. (KnownNat n, IsHalideType a) => Maybe Text -> IO (ForeignPtr CxxImageParam) mkBufferParameter maybeName = do with (halideTypeFor (Proxy @a)) $ \t -> do let d = fromIntegral $ natVal (Proxy @n) createWithoutName = [CU.exp| Halide::ImageParam* { new Halide::ImageParam{Halide::Type{*$(halide_type_t* t)}, $(int d)} } |] deleter = [C.funPtr| void deleteImageParam(Halide::ImageParam* p) { delete p; } |] createWithName name = let s = T.encodeUtf8 name in [CU.exp| Halide::ImageParam* { new Halide::ImageParam{ Halide::Type{*$(halide_type_t* t)}, $(int d), std::string{$bs-ptr:s, static_cast($bs-len:s)}} } |] newForeignPtr deleter =<< maybe createWithoutName createWithName maybeName getBufferParameter :: forall n a . (KnownNat n, IsHalideType a) => Maybe Text -> IORef (Maybe (ForeignPtr CxxImageParam)) -> IO (ForeignPtr CxxImageParam) getBufferParameter name r = readIORef r >>= \case Just fp -> pure fp Nothing -> do fp <- mkBufferParameter @n @a name writeIORef r (Just fp) pure fp -- | Same as 'withFunc', but ensures that we're dealing with 'Param' instead of a 'Func'. withBufferParam :: forall n a b . (HasCallStack, KnownNat n, IsHalideType a) => Func 'ParamTy n a -> (Ptr CxxImageParam -> IO b) -> IO b withBufferParam (Param r) action = getBufferParameter @n @a Nothing r >>= flip withForeignPtr action -- instance (KnownNat n, IsHalideType a) => Named (Func 'ParamTy n a) where -- setName :: Func 'ParamTy n a -> Text -> IO () -- setName (Param r) name = do -- readIORef r >>= \case -- Just _ -> error "the name of this Func has already been set" -- Nothing -> do -- fp <- mkBufferParameter @n @a (Just name) -- writeIORef r (Just fp) -- | Get the underlying pointer to @Halide::Func@ and invoke an 'IO' action with it. withFunc :: (KnownNat n, IsHalideType a) => Func t n a -> (Ptr CxxFunc -> IO b) -> IO b withFunc f = withForeignPtr (funcToForeignPtr f) wrapCxxFunc :: Ptr CxxFunc -> IO (Func 'FuncTy n a) wrapCxxFunc = fmap Func . newForeignPtr deleter where deleter = [C.funPtr| void deleteFunc(Halide::Func *x) { delete x; } |] forceFunc :: forall t n a. (KnownNat n, IsHalideType a) => Func t n a -> IO (Func 'FuncTy n a) forceFunc x@(Func _) = pure x forceFunc (Param r) = do fp <- getBufferParameter @n @a Nothing r withForeignPtr fp $ \p -> wrapCxxFunc =<< [CU.exp| Halide::Func* { new Halide::Func{static_cast(*$(Halide::ImageParam* p))} } |] funcToForeignPtr :: (KnownNat n, IsHalideType a) => Func t n a -> ForeignPtr CxxFunc funcToForeignPtr x = unsafePerformIO $! forceFunc x >>= \(Func fp) -> pure fp -- | Define a Halide function. -- -- @define "f" i e@ defines a Halide function called "f" such that @f[i] = e@. -- -- Here, @i@ is an @n@-element tuple of t'Var', i.e. the following are all valid: -- -- >>> [x, y, z] <- mapM mkVar ["x", "y", "z"] -- >>> f1 <- define "f1" x (0 :: Expr Float) -- >>> f2 <- define "f2" (x, y) (0 :: Expr Float) -- >>> f3 <- define "f3" (x, y, z) (0 :: Expr Float) define :: ( IsTuple (Arguments ts) i , All ((~) Var) ts , Length ts ~ n , KnownNat n , IsHalideType a ) => Text -> i -> Expr a -> IO (Func 'FuncTy n a) define name args expr = asVectorOf @((~) (Expr Int32)) asVar (fromTuple args) $ \x -> do let s = T.encodeUtf8 name asExpr expr $ \y -> wrapCxxFunc =<< [CU.block| Halide::Func* { Halide::Func f{std::string{$bs-ptr:s, static_cast($bs-len:s)}}; f(*$(std::vector* x)) = *$(Halide::Expr* y); return new Halide::Func{f}; } |] -- | Create an update definition for a Halide function. -- -- @update f i e@ creates an update definition for @f@ that performs @f[i] = e@. update :: ( IsTuple (Arguments ts) i , All ((~) (Expr Int32)) ts , Length ts ~ n , KnownNat n , IsHalideType a ) => Func 'FuncTy n a -> i -> Expr a -> IO () update func args expr = withFunc func $ \f -> asVectorOf @((~) (Expr Int32)) asExpr (fromTuple args) $ \x -> asExpr expr $ \y -> [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Func* f)->operator()(*$(std::vector* x)) = *$(Halide::Expr* y); }); } |] infix 9 ! -- | Apply a Halide function. Conceptually, @f ! i@ is equivalent to @f[i]@, i.e. -- indexing into a lazy array. (!) :: ( IsTuple (Arguments ts) i , All ((~) (Expr Int32)) ts , Length ts ~ n , KnownNat n , IsHalideType a ) => Func t n a -> i -> Expr a (!) func args = unsafePerformIO $ withFunc func $ \f -> asVectorOf @((~) (Expr Int32)) asExpr (fromTuple args) $ \x -> cxxConstructExpr $ \ptr -> [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{ $(Halide::Func* f)->operator()(*$(std::vector* x))} } |] -- | Get a particular dimension of a pipeline parameter. dim :: forall n a . (HasCallStack, KnownNat n, IsHalideType a) => Int -> Func 'ParamTy n a -> IO Dimension dim k func | 0 <= k && k < fromIntegral (natVal (Proxy @n)) = let n = fromIntegral k in withBufferParam func $ \f -> wrapCxxDimension =<< [CU.exp| Halide::Internal::Dimension* { new Halide::Internal::Dimension{$(Halide::ImageParam* f)->dim($(int n))} } |] | otherwise = error $ "invalid dimension index: " <> show k <> "; Func is " <> show (natVal (Proxy @n)) <> "-dimensional" -- | Write out the loop nests specified by the schedule for this function. -- -- Helpful for understanding what a schedule is doing. -- -- For more info, see -- [@Halide::Func::print_loop_nest@](https://halide-lang.org/docs/class_halide_1_1_func.html#a03f839d9e13cae4b87a540aa618589ae) -- printLoopNest :: (KnownNat n, IsHalideType r) => Func n r -> IO () -- printLoopNest func = withFunc func $ \f -> -- [C.exp| void { $(Halide::Func* f)->print_loop_nest() } |] -- | Get the loop nests specified by the schedule for this function. -- -- Helpful for understanding what a schedule is doing. -- -- For more info, see -- [@Halide::Func::print_loop_nest@](https://halide-lang.org/docs/class_halide_1_1_func.html#a03f839d9e13cae4b87a540aa618589ae) prettyLoopNest :: (KnownNat n, IsHalideType r) => Func t n r -> IO Text prettyLoopNest func = withFunc func $ \f -> peekAndDeleteCxxString =<< [C.throwBlock| std::string* { return handle_halide_exceptions([=]() { return new std::string{Halide::Internal::print_loop_nest( std::vector{$(Halide::Func* f)->function()})}; }); } |] -- | Evaluate this function over a rectangular domain. realize :: forall n a t b . (KnownNat n, IsHalideType a) => Func t n a -- ^ Function to evaluate -> [Int] -- ^ Domain over which to evaluate -> (Ptr (HalideBuffer n a) -> IO b) -- ^ What to do with the buffer afterwards. Note that the buffer is allocated only temporary, -- so do not return it directly. -> IO b realize func shape action = withFunc func $ \f -> allocaCpuBuffer shape $ \buf -> do let raw = castPtr buf [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Func* f)->realize( Halide::Pipeline::RealizationArg{$(halide_buffer_t* raw)}); }); } |] action buf -- \| Evaluate this function over a one-dimensional domain and return the -- resulting buffer or buffers. -- realize1D -- :: forall a t -- . IsHalideType a -- => Int -- -- ^ @size@ of the domain. The function will be evaluated on @[0, ..., size -1]@ -- -> Func t 1 a -- -- ^ Function to evaluate -- -> IO (Vector a) -- realize1D size func = do -- buf <- SM.new size -- withHalideBuffer @1 @a buf $ \x -> do -- let b = castPtr x -- withFunc func $ \f -> -- [CU.exp| void { -- $(Halide::Func* f)->realize( -- Halide::Pipeline::RealizationArg{$(halide_buffer_t* b)}) } |] -- S.unsafeFreeze buf -- | A view pattern to specify the name of a buffer argument. -- -- Example usage: -- -- >>> :{ -- _ <- compile $ \(buffer "src" -> src) -> do -- i <- mkVar "i" -- define "dest" i $ (src ! i :: Expr Float) -- :} -- -- or if we want to specify the dimension and type, we can use type applications: -- -- >>> :{ -- _ <- compile $ \(buffer @1 @Float "src" -> src) -> do -- i <- mkVar "i" -- define "dest" i $ src ! i -- :} buffer :: forall n a. (KnownNat n, IsHalideType a) => Text -> Func 'ParamTy n a -> Func 'ParamTy n a buffer name p@(Param r) = unsafePerformIO $ do _ <- getBufferParameter @n @a (Just name) r pure p -- | Similar to 'buffer', but for scalar parameters. -- -- Example usage: -- -- >>> :{ -- _ <- compile $ \(scalar @Float "a" -> a) -> do -- i <- mkVar "i" -- define "dest" i $ a -- :} scalar :: forall a. IsHalideType a => Text -> Expr a -> Expr a scalar name (ScalarParam r) = unsafePerformIO $ do readIORef r >>= \case Just _ -> error "the name of this Expr has already been set" Nothing -> do fp <- mkScalarParameter @a (Just name) writeIORef r (Just fp) pure (ScalarParam r) scalar _ _ = error "cannot set the name of an expression that is not a parameter" wrapCxxStage :: (KnownNat n, IsHalideType a) => Ptr CxxStage -> IO (Stage n a) wrapCxxStage = fmap Stage . newForeignPtr deleter where deleter = [C.funPtr| void deleteStage(Halide::Stage* p) { delete p; } |] withCxxStage :: (KnownNat n, IsHalideType a) => Stage n a -> (Ptr CxxStage -> IO b) -> IO b withCxxStage (Stage fp) = withForeignPtr fp -- | Get the pure stage of a 'Func' for the purposes of scheduling it. getStage :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Stage n a) getStage func = withFunc func $ \func' -> [CU.exp| Halide::Stage* { new Halide::Stage{static_cast(*$(Halide::Func* func'))} } |] >>= wrapCxxStage -- | Return 'True' when the function has update definitions, 'False' otherwise. hasUpdateDefinitions :: (KnownNat n, IsHalideType a) => Func t n a -> IO Bool hasUpdateDefinitions func = withFunc func $ \func' -> toBool <$> [CU.exp| bool { $(const Halide::Func* func')->has_update_definition() } |] -- | Get a handle to an update step for the purposes of scheduling it. getUpdateStage :: (KnownNat n, IsHalideType a) => Int -> Func 'FuncTy n a -> IO (Stage n a) getUpdateStage k func = withFunc func $ \func' -> let k' = fromIntegral k in [CU.exp| Halide::Stage* { new Halide::Stage{$(Halide::Func* func')->update($(int k'))} } |] >>= wrapCxxStage -- | Identify the loop nest corresponding to some dimension of some function. getLoopLevelAtStage :: (KnownNat n, IsHalideType a) => Func t n a -> Expr Int32 -> Int -- ^ update index -> IO (LoopLevel 'LockedTy) getLoopLevelAtStage func var stageIndex = withFunc func $ \f -> asVarOrRVar var $ \i -> do (SomeLoopLevel level) <- wrapCxxLoopLevel =<< [C.throwBlock| Halide::LoopLevel* { return handle_halide_exceptions([=](){ return new Halide::LoopLevel{*$(const Halide::Func* f), *$(const Halide::VarOrRVar* i), $(int k)}; }); } |] case level of LoopLevel _ -> pure level _ -> error $ "getLoopLevelAtStage: got " <> show level <> ", but expected a LoopLevel 'LockedTy" where k = fromIntegral stageIndex -- | Same as 'getLoopLevelAtStage' except that the stage is @-1@. getLoopLevel :: (KnownNat n, IsHalideType a) => Func t n a -> Expr Int32 -> IO (LoopLevel 'LockedTy) getLoopLevel f i = getLoopLevelAtStage f i (-1) -- | Allocate storage for this function within a particular loop level. -- -- Scheduling storage is optional, and can be used to separate the loop level at which storage is allocated -- from the loop level at which computation occurs to trade off between locality and redundant work. -- -- For more info, see [Halide::Func::store_at](https://halide-lang.org/docs/class_halide_1_1_func.html#a417c08f8aa3a5cdf9146fba948b65193). storeAt :: (KnownNat n, IsHalideType a) => Func 'FuncTy n a -> LoopLevel t -> IO (Func 'FuncTy n a) storeAt func level = do withFunc func $ \f -> withCxxLoopLevel level $ \l -> [CU.exp| void { $(Halide::Func* f)->store_at(*$(const Halide::LoopLevel* l)) } |] pure func -- | Schedule a function to be computed within the iteration over a given loop level. -- -- For more info, see [Halide::Func::compute_at](https://halide-lang.org/docs/class_halide_1_1_func.html#a800cbcc3ca5e3d3fa1707f6e1990ec83). computeAt :: (KnownNat n, IsHalideType a) => Func 'FuncTy n a -> LoopLevel t -> IO (Func 'FuncTy n a) computeAt func level = do withFunc func $ \f -> withCxxLoopLevel level $ \l -> [CU.exp| void { $(Halide::Func* f)->compute_at(*$(const Halide::LoopLevel* l)) } |] pure func -- | Wrap a buffer into a t'Func'. -- -- Suppose, we are defining a pipeline that adds together two vectors, and we'd like to call 'realize' to -- evaluate it directly, how do we pass the vectors to the t'Func'? 'asBufferParam' allows to do exactly this. -- -- > asBuffer [1, 2, 3] $ \a -> -- > asBuffer [4, 5, 6] $ \b -> do -- > i <- mkVar "i" -- > f <- define "vectorAdd" i $ a ! i + b ! i -- > realize f [3] $ \result -> -- > print =<< peekToList f asBufferParam :: forall n a t b . IsHalideBuffer t n a => t -- ^ Object to treat as a buffer -> (Func 'ParamTy n a -> IO b) -- ^ What to do with the __temporary__ buffer -> IO b asBufferParam arr action = withHalideBuffer @n @a arr $ \arr' -> do param <- mkBufferParameter @n @a Nothing withForeignPtr param $ \param' -> let buf = (castPtr arr' :: Ptr RawHalideBuffer) in [CU.block| void { $(Halide::ImageParam* param')->set(Halide::Buffer<>{*$(const halide_buffer_t* buf)}); } |] action . Param =<< newIORef (Just param)