{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.Halide.Buffer
(
HalideBuffer (..)
, allocaCpuBuffer
, allocaBuffer
, IsListPeek (..)
, peekScalar
, IsHalideBuffer (..)
, withHalideBuffer
, bufferFromPtrShapeStrides
, bufferFromPtrShape
, RawHalideBuffer (..)
, HalideDimension (..)
, HalideDeviceInterface
, rowMajorStrides
, colMajorStrides
, isDeviceDirty
, isHostDirty
, getBufferExtent
, bufferCopyToHost
, withCopiedToHost
, withCropped
)
where
import Control.Exception (bracket_)
import Control.Monad (forM, unless, when)
import Control.Monad.ST (RealWorld)
import Data.Int
import Data.Kind (Type)
import Data.List qualified as List
import Data.Proxy
import Data.Vector.Storable qualified as S
import Data.Vector.Storable.Mutable qualified as SM
import Data.Word
import Foreign.Marshal.Alloc (alloca, free, mallocBytes)
import Foreign.Marshal.Array
import Foreign.Marshal.Utils
import Foreign.Ptr
import Foreign.Storable
import GHC.Stack (HasCallStack)
import GHC.TypeNats
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Context
import Language.Halide.Target
import Language.Halide.Type
import Prelude hiding (min)
-- | Description of one dimension of a buffer.
--
-- The 'Storable' instance in this module marshals the four fields as
-- consecutive 32-bit words; presumably this mirrors Halide's
-- @halide_dimension_t@ (confirm against the Halide runtime header).
data HalideDimension = HalideDimension
  { halideDimensionMin :: {-# UNPACK #-} !Int32
  -- ^ Lower bound of the dimension ('simpleDimension' always uses 0).
  , halideDimensionExtent :: {-# UNPACK #-} !Int32
  -- ^ Number of elements along the dimension.
  , halideDimensionStride :: {-# UNPACK #-} !Int32
  -- ^ Distance between neighbouring elements, counted in elements.
  , halideDimensionFlags :: {-# UNPACK #-} !Word32
  -- ^ Extra flags ('simpleDimension' always uses 0).
  }
  deriving stock (Eq, Show, Read)
-- | Marshals a 'HalideDimension' as four 32-bit fields stored back to back:
-- min at byte 0, extent at 4, stride at 8 and flags at 12 (16 bytes in
-- total, 4-byte aligned).
instance Storable HalideDimension where
  sizeOf = const 16
  {-# INLINE sizeOf #-}
  alignment = const 4
  {-# INLINE alignment #-}
  peek p = do
    lowerBound <- peekByteOff p 0
    extent <- peekByteOff p 4
    stride <- peekByteOff p 8
    flags <- peekByteOff p 12
    pure (HalideDimension lowerBound extent stride flags)
  {-# INLINE peek #-}
  poke p dimension = do
    -- Lazy pattern: the record is only forced once the first field is written.
    let HalideDimension lowerBound extent stride flags = dimension
    pokeByteOff p 0 lowerBound
    pokeByteOff p 4 extent
    pokeByteOff p 8 stride
    pokeByteOff p 12 flags
  {-# INLINE poke #-}
-- | Build a 'HalideDimension' with the given extent and stride, a minimum
-- of 0 and no flags set.
--
-- Both arguments are narrowed from 'Int' to 'Int32' with 'fromIntegral',
-- i.e. without any range check.
simpleDimension :: Int -> Int -> HalideDimension
simpleDimension extent stride =
  HalideDimension
    { halideDimensionMin = 0
    , halideDimensionExtent = fromIntegral extent
    , halideDimensionStride = fromIntegral stride
    , halideDimensionFlags = 0
    }
{-# INLINE simpleDimension #-}
-- | Strides (in elements) of a densely packed array whose /last/ dimension
-- varies fastest.
--
-- Each stride is the product of all extents to its right:
-- @rowMajorStrides [2, 3, 4] == [12, 4, 1]@. An empty list of extents
-- yields an empty list of strides.
rowMajorStrides
  :: Integral a
  => [a]
  -- ^ Extents
  -> [a]
rowMajorStrides extents = drop 1 suffixProducts
  where
    -- Head is the total element count, which is not a stride; drop it.
    suffixProducts = scanr (*) 1 extents
-- | Strides (in elements) of a densely packed array whose /first/ dimension
-- varies fastest.
--
-- Each stride is the product of all extents to its left:
-- @colMajorStrides [2, 3, 4] == [1, 2, 6]@.
--
-- The function is total: an empty list of extents yields an empty list of
-- strides (the previous @scanl (*) 1 . init@ formulation hit the partial
-- 'init' on @[]@).
colMajorStrides
  :: Integral a
  => [a]
  -- ^ Extents
  -> [a]
colMajorStrides extents =
  -- scanl yields one prefix product more than there are extents; zipping
  -- against the input trims the result to exactly one stride per dimension.
  zipWith const (scanl (*) 1 extents) extents
-- | Uninhabited tag type that only ever appears behind a 'Ptr': the inline
-- C++ snippets in this module treat a @'Ptr' 'HalideDeviceInterface'@ as a
-- @const halide_device_interface_t*@.
data HalideDeviceInterface
-- | Untyped, low-level view of a buffer; pointers to this record are handed
-- to inline C++ as @halide_buffer_t*@.
--
-- Nothing here tracks the element type or rank at the type level; see
-- 'HalideBuffer' for the wrapper that does.
data RawHalideBuffer = RawHalideBuffer
  { halideBufferDevice :: !Word64
  -- ^ Device-side handle; 0 is treated as "no device allocation".
  , halideBufferDeviceInterface :: !(Ptr HalideDeviceInterface)
  -- ^ Device interface used to manage the device allocation (may be null).
  , halideBufferHost :: !(Ptr Word8)
  -- ^ Pointer to the data in host memory (may be null).
  , halideBufferFlags :: !Word64
  -- ^ Flag bits (initialised to 0 by the helpers in this module).
  , halideBufferType :: !HalideType
  -- ^ Element type of the buffer.
  , halideBufferDimensions :: !Int32
  -- ^ Number of entries in 'halideBufferDim'.
  , halideBufferDim :: !(Ptr HalideDimension)
  -- ^ Pointer to an array of per-dimension descriptions.
  , halideBufferPadding :: !(Ptr ())
  -- ^ Trailing pointer-sized slot; always set to null by this module.
  }
  deriving stock (Eq, Show)
-- | A 'RawHalideBuffer' tagged with its rank @n@ and element type @a@.
--
-- Both parameters are phantom: they carry no runtime data, but the helpers
-- in this module check the number of dimensions against @n@ and derive the
-- element type from @a@ when constructing a buffer.
newtype HalideBuffer (n :: Nat) (a :: Type) = HalideBuffer
  { unHalideBuffer :: RawHalideBuffer
  -- ^ Forget the type-level rank and element type.
  }
  deriving stock (Eq, Show)
-- Top-level Template Haskell splice defined outside this file. It
-- presumably installs the inline-c context (Halide headers and type
-- mappings) that the C++ quasi-quotes further down depend on, so it likely
-- has to stay above the first quasi-quote -- TODO confirm.
importHalide
-- | Marshals a 'RawHalideBuffer' using hard-coded byte offsets:
--
-- > device 0 | device_interface 8 | host 16 | flags 24
-- > type 32  | dimensions 36      | dim 40  | padding 48
--
-- for a total of 56 bytes with 8-byte alignment. These offsets assume
-- 8-byte pointers; presumably they match @halide_buffer_t@ on 64-bit
-- platforms (TODO confirm before targeting anything else).
instance Storable RawHalideBuffer where
  sizeOf = const 56
  alignment = const 8
  peek p = do
    device <- peekByteOff p 0
    deviceInterface <- peekByteOff p 8
    host <- peekByteOff p 16
    flags <- peekByteOff p 24
    elementType <- peekByteOff p 32
    rank <- peekByteOff p 36
    dims <- peekByteOff p 40
    padding <- peekByteOff p 48
    pure $
      RawHalideBuffer
        { halideBufferDevice = device
        , halideBufferDeviceInterface = deviceInterface
        , halideBufferHost = host
        , halideBufferFlags = flags
        , halideBufferType = elementType
        , halideBufferDimensions = rank
        , halideBufferDim = dims
        , halideBufferPadding = padding
        }
  poke p buffer = do
    -- Lazy pattern: the record is only forced once the first field is written.
    let RawHalideBuffer device deviceInterface host flags elementType rank dims padding = buffer
    pokeByteOff p 0 device
    pokeByteOff p 8 deviceInterface
    pokeByteOff p 16 host
    pokeByteOff p 24 flags
    pokeByteOff p 32 elementType
    pokeByteOff p 36 rank
    pokeByteOff p 40 dims
    pokeByteOff p 48 padding
-- | Temporarily view caller-owned host memory as a 'HalideBuffer'.
--
-- The dimension array and the buffer record only live for the duration of
-- the action; the data behind the pointer is neither copied nor freed.
--
-- Calls 'error' when
--
--   * the shape and the strides have different lengths (previously the
--     longer list was silently truncated by 'zipWith'),
--   * the number of dimensions differs from the type-level rank @n@, or
--   * the buffer's @device@ field is non-zero after the action has finished.
bufferFromPtrShapeStrides
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Ptr a
  -- ^ Pointer to the data in host memory
  -> [Int]
  -- ^ Extents, one per dimension (in elements)
  -> [Int]
  -- ^ Strides, one per dimension (in elements)
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ Action to run with the temporary buffer
  -> IO b
bufferFromPtrShapeStrides p shape stride action = do
  -- Reject mismatched inputs up front instead of letting zipWith drop the
  -- surplus entries.
  unless (length shape == length stride) $
    error $
      "shape and strides have different lengths: "
        <> show (length shape)
        <> " vs. "
        <> show (length stride)
  withArrayLen (zipWith simpleDimension shape stride) $ \n dim -> do
    unless (n == fromIntegral (natVal (Proxy @n))) $
      error $
        "specified wrong number of dimensions: "
          <> show n
          <> "; expected "
          <> show (natVal (Proxy @n))
          <> " from the type declaration"
    let !buffer =
          RawHalideBuffer
            { halideBufferDevice = 0
            , halideBufferDeviceInterface = nullPtr
            , halideBufferHost = castPtr p
            , halideBufferFlags = 0
            , halideBufferType = halideTypeFor (Proxy :: Proxy a)
            , halideBufferDimensions = fromIntegral n
            , halideBufferDim = dim
            , halideBufferPadding = nullPtr
            }
    with buffer $ \bufferPtr -> do
      r <- action (castPtr bufferPtr)
      -- A non-zero device handle means the action left data on the device.
      hasDataOnDevice <-
        toBool <$> [CU.exp| bool { $(halide_buffer_t* bufferPtr)->device } |]
      when hasDataOnDevice $
        error "the Buffer still references data on the device; did you forget to call copyToHost?"
      pure r
-- | Like 'bufferFromPtrShapeStrides', but for densely packed data: the
-- strides are derived from the shape with 'colMajorStrides', i.e. the first
-- dimension is the contiguous one.
bufferFromPtrShape
  :: (HasCallStack, KnownNat n, IsHalideType a)
  => Ptr a
  -- ^ Pointer to the data in host memory
  -> [Int]
  -- ^ Extents, one per dimension (in elements)
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ Action to run with the temporary buffer
  -> IO b
bufferFromPtrShape p shape action =
  bufferFromPtrShapeStrides p shape denseStrides action
  where
    denseStrides = colMajorStrides shape
-- | Values of type @t@ that can be presented to Halide as an
-- @n@-dimensional buffer with elements of type @a@.
class (KnownNat n, IsHalideType a) => IsHalideBuffer t n a where
  -- | Run the continuation with a pointer to a 'HalideBuffer' describing the
  -- value. The instances in this module build the buffer with
  -- 'bufferFromPtrShape', so the pointer is only valid inside the
  -- continuation.
  withHalideBufferImpl :: t -> (Ptr (HalideBuffer n a) -> IO b) -> IO b
-- | Front end for 'withHalideBufferImpl' whose type variables are ordered
-- @n@, @a@, @t@, so that rank and element type can be fixed with type
-- applications (e.g. @withHalideBuffer \@2 \@Float@).
withHalideBuffer :: forall n a t b. IsHalideBuffer t n a => t -> (Ptr (HalideBuffer n a) -> IO b) -> IO b
withHalideBuffer value action = withHalideBufferImpl @t @n @a value action
-- | A storable vector is a one-dimensional buffer. The buffer points
-- straight at the vector's storage; no copy is made here.
instance IsHalideType a => IsHalideBuffer (S.Vector a) 1 a where
  withHalideBufferImpl vec action =
    S.unsafeWith vec (\dataPtr -> bufferFromPtrShape dataPtr [S.length vec] action)
-- | A mutable storable vector is a one-dimensional buffer. The buffer
-- points straight at the vector's storage, so writes made through the
-- buffer are visible in the vector.
instance IsHalideType a => IsHalideBuffer (S.MVector RealWorld a) 1 a where
  withHalideBufferImpl mvec action =
    SM.unsafeWith mvec (\dataPtr -> bufferFromPtrShape dataPtr [SM.length mvec] action)
-- | A flat list is a one-dimensional buffer. The elements are first copied
-- into a fresh storable vector, so nothing written through the buffer can
-- flow back into the list.
instance IsHalideType a => IsHalideBuffer [a] 1 a where
  withHalideBufferImpl elements action =
    withHalideBuffer (S.fromList elements) action
-- | A list of rows is a two-dimensional buffer of shape
-- @[number of rows, row length]@.
--
-- The elements are copied into a temporary vector laid out for
-- 'bufferFromPtrShape' (first index contiguous), hence the transpose.
--
-- Calls 'error' unless every row has the same length. The previous check
-- only compared the total element count, so inputs such as
-- @[[1,2],[3],[4,5,6]]@ (6 == 3 * 2) slipped through with scrambled data.
instance IsHalideType a => IsHalideBuffer [[a]] 2 a where
  withHalideBufferImpl xs f = do
    let d0 = length xs
        d1 = case xs of
          [] -> 0
          (row : _) -> length row
    unless (all ((== d1) . length) xs) $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    let v = S.fromList (List.concat (List.transpose xs))
    S.unsafeWith v $ \cpuPtr ->
      bufferFromPtrShape cpuPtr [d0, d1] f
-- | A triply nested list is a three-dimensional buffer of shape
-- @[d0, d1, d2]@, where @xs !! i !! j !! k@ is the element at @(i, j, k)@.
--
-- The elements are copied into a temporary vector laid out for
-- 'bufferFromPtrShape' (first index contiguous).
--
-- Calls 'error' unless all sub-lists at each nesting level have the same
-- length. The previous check only compared the total element count, which
-- irregular inputs can satisfy by coincidence.
instance IsHalideType a => IsHalideBuffer [[[a]]] 3 a where
  withHalideBufferImpl xs f = do
    let d0 = length xs
        d1 = case xs of
          (plane : _) -> length plane
          [] -> 0
        d2 = case xs of
          ((row : _) : _) -> length row
          _ -> 0
        isRegular =
          all (\plane -> length plane == d1 && all ((== d2) . length) plane) xs
    unless isRegular $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    -- Reorder [i][j][k] so that the flattened list runs over k, then j, then
    -- i (i innermost).
    let v =
          S.fromList
            . List.concat
            . List.concatMap List.transpose
            . List.transpose
            . fmap List.transpose
            $ xs
    S.unsafeWith v $ \cpuPtr ->
      bufferFromPtrShape cpuPtr [d0, d1, d2] f
-- | A four-fold nested list is a four-dimensional buffer of shape
-- @[d0, d1, d2, d3]@, where @xs !! i !! j !! k !! l@ is the element at
-- @(i, j, k, l)@.
--
-- The elements are copied into a temporary vector laid out for
-- 'bufferFromPtrShape' (first index contiguous).
--
-- Calls 'error' unless all sub-lists at each nesting level have the same
-- length. The previous check only compared the total element count, which
-- irregular inputs can satisfy by coincidence.
instance IsHalideType a => IsHalideBuffer [[[[a]]]] 4 a where
  withHalideBufferImpl xs f = do
    let d0 = length xs
        d1 = case xs of
          (cube : _) -> length cube
          [] -> 0
        d2 = case xs of
          ((plane : _) : _) -> length plane
          _ -> 0
        d3 = case xs of
          (((row : _) : _) : _) -> length row
          _ -> 0
        isRegularPlane plane = length plane == d2 && all ((== d3) . length) plane
        isRegularCube cube = length cube == d1 && all isRegularPlane cube
    unless (all isRegularCube xs) $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    -- Reorder [i][j][k][l] so that the flattened list runs over l, then k,
    -- then j, then i (i innermost).
    let v =
          S.fromList
            . concat
            . concat
            . concatMap (fmap List.transpose . List.transpose . fmap List.transpose)
            . List.transpose
            . fmap (List.transpose . fmap List.transpose)
            $ xs
    S.unsafeWith v $ \cpuPtr ->
      bufferFromPtrShape cpuPtr [d0, d1, d2, d3] f
-- | Monadic variant of 'when': run the condition first and execute the
-- second action only if it produced 'True'.
whenM :: Monad m => m Bool -> m () -> m ()
whenM cond f = do
  holds <- cond
  when holds f
-- | Run an action with a temporary buffer of the given shape, allocated for
-- 'hostTarget'. This is 'allocaBuffer' specialised to the host.
allocaCpuBuffer
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => [Int]
  -- ^ Extents, one per dimension (in elements)
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ Action to run with the temporary buffer
  -> IO b
allocaCpuBuffer shape action = allocaBuffer hostTarget shape action
-- | Number of bytes needed to store every element of the buffer: the
-- product of all extents times the size of one element.
--
-- The element size is rounded /up/ to whole bytes. The previous
-- @bits * lanes / 8@ truncated, giving 0 bytes for element types narrower
-- than 8 bits; for widths that are a multiple of 8 the result is unchanged.
-- (Halide's own @halide_type_t::bytes()@ appears to round up the same way.)
--
-- NOTE(review): extents are multiplied as @size_t@, so a negative extent
-- would wrap around; callers are assumed to pass non-negative extents.
getTotalBytes :: Ptr RawHalideBuffer -> IO Int
getTotalBytes buf =
  fromIntegral
    <$> [CU.block| size_t {
          auto const& b = *$(const halide_buffer_t* buf);
          auto const n = std::accumulate(b.dim, b.dim + b.dimensions, size_t{1},
                                         [](auto acc, auto const& dim) { return acc * dim.extent; });
          return n * ((b.type.bits * b.type.lanes + 7) / 8);
        } |]
-- | Allocate 'getTotalBytes' bytes with 'mallocBytes' and store the
-- resulting pointer in the buffer's @host@ field. Any previous value of
-- that field is overwritten.
allocateHostMemory :: Ptr RawHalideBuffer -> IO ()
allocateHostMemory buf = do
  size <- getTotalBytes buf
  ptr <- mallocBytes size
  [CU.block| void { $(halide_buffer_t* buf)->host = $(uint8_t* ptr); } |]
-- | Release the buffer's host memory with 'free'.
--
-- The @host@ field is reset to null /before/ the memory is released, so the
-- buffer never holds a dangling host pointer.
freeHostMemory :: Ptr RawHalideBuffer -> IO ()
freeHostMemory buf =
  free
    =<< [CU.block| uint8_t* {
          auto& b = *$(halide_buffer_t* buf);
          auto const p = b.host;
          b.host = nullptr;
          return p;
        } |]
-- | Ask the given device interface to allocate device memory for the
-- buffer by calling its @device_malloc@ entry point.
--
-- The status code returned by @device_malloc@ is now checked (it used to be
-- discarded): any non-zero value is reported through 'error', so a failed
-- allocation no longer goes unnoticed until the buffer is used or freed.
allocateDeviceMemory :: Ptr HalideDeviceInterface -> Ptr RawHalideBuffer -> IO ()
allocateDeviceMemory interface buf = do
  status <-
    [CU.block| int {
      auto const* interface = $(const halide_device_interface_t* interface);
      return interface->device_malloc(nullptr, $(halide_buffer_t* buf), interface);
    } |]
  unless (status == 0) $
    error $
      "cannot allocate device memory: device_malloc failed with error code "
        <> show status
-- | Release the buffer's device memory through the @device_free@ entry
-- point of its own device interface, then reset the @device@ field to 0.
--
-- Calls 'error' if the buffer has no device interface.
freeDeviceMemory :: HasCallStack => Ptr RawHalideBuffer -> IO ()
freeDeviceMemory buf =
  [CU.exp| const halide_device_interface_t* { $(const halide_buffer_t* buf)->device_interface } |]
    >>= \deviceInterface ->
      if deviceInterface == nullPtr
        then error "cannot free device memory: device_interface is NULL"
        else
          [CU.block| void {
            $(halide_buffer_t* buf)->device_interface->device_free(nullptr, $(halide_buffer_t* buf));
            $(halide_buffer_t* buf)->device = 0;
          } |]
-- | Run an action with a temporary, densely packed buffer of the given
-- shape, allocated for the given target.
--
-- If 'getDeviceInterface' returns a null interface for the target, the
-- memory is allocated on the host; otherwise it is allocated on the device
-- through that interface. Either way the memory is released when the action
-- finishes, including when it throws.
--
-- Calls 'error' when
--
--   * the length of the shape differs from the type-level rank @n@,
--   * a host buffer ends up with a non-zero @device@ field, or
--   * a device buffer ends up with a non-null @host@ pointer.
allocaBuffer
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Target
  -- ^ Target that decides where the memory lives
  -> [Int]
  -- ^ Extents, one per dimension (in elements)
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ Action to run with the temporary buffer
  -> IO b
allocaBuffer target shape action = do
  deviceInterface <- getDeviceInterface target
  let onHost = deviceInterface == nullPtr
      expectedRank = natVal (Proxy @n)
      dimensions = zipWith simpleDimension shape (colMajorStrides shape)
      -- Matching pair of acquire/release functions for the chosen memory.
      (allocate, deallocate)
        | onHost = (allocateHostMemory, freeHostMemory)
        | otherwise = (allocateDeviceMemory deviceInterface, freeDeviceMemory)
  withArrayLen dimensions $ \n dim -> do
    when (n /= fromIntegral expectedRank) $
      error $
        "specified wrong number of dimensions: "
          <> show n
          <> "; expected "
          <> show expectedRank
          <> " from the type declaration"
    -- Fields in declaration order: device, device interface, host, flags,
    -- element type, rank, dimension array, padding.
    let rawBuffer =
          RawHalideBuffer 0 nullPtr nullPtr 0 (halideTypeFor (Proxy @a)) (fromIntegral n) dim nullPtr
    with rawBuffer $ \buf ->
      bracket_ (allocate buf) (deallocate buf) $ do
        r <- action (castPtr buf)
        isHostNull <- toBool <$> [CU.exp| bool { $(halide_buffer_t* buf)->host == nullptr } |]
        isDeviceNull <- toBool <$> [CU.exp| bool { $(halide_buffer_t* buf)->device == 0 } |]
        when (onHost && not isDeviceNull) $
          error $
            "buffer was allocated on host, but its device pointer is not NULL"
              <> "; did you forget a copyToHost in your pipeline?"
        when (not onHost && not isHostNull) $
          error $
            "buffer was allocated on device, but its host pointer is not NULL"
              <> "; did you add an extra copyToHost?"
        pure r
-- | Look up the device interface for the device API of the given target.
--
-- Returns a null pointer for 'DeviceNone' and 'DeviceHost'. For every other
-- device API the interface is obtained from
-- @Halide::get_device_interface_for_device_api@; the call is wrapped in
-- @handle_halide_exceptions@ inside a @throwBlock@, so C++ exceptions are
-- presumably re-thrown on the Haskell side.
getDeviceInterface :: Target -> IO (Ptr HalideDeviceInterface)
getDeviceInterface target = lookupInterface (deviceAPIForTarget target)
  where
    lookupInterface = \case
      DeviceNone -> pure nullPtr
      DeviceHost -> pure nullPtr
      device ->
        -- The C++ side receives the device API as a plain int.
        let api = fromIntegral (fromEnum device)
         in withCxxTarget target $ \target' ->
              [C.throwBlock| const halide_device_interface_t* {
                return handle_halide_exceptions([=](){
                  auto const device = static_cast<Halide::DeviceAPI>($(int api));
                  auto const& target = *$(const Halide::Target* target');
                  return Halide::get_device_interface_for_device_api(device, target, "getDeviceInterface");
                });
              } |]
-- | Query the @device_dirty@ flag of a buffer by calling
-- @halide_buffer_t::device_dirty()@.
isDeviceDirty :: Ptr RawHalideBuffer -> IO Bool
isDeviceDirty p = do
  dirty <- [CU.exp| bool { $(const halide_buffer_t* p)->device_dirty() } |]
  pure (toBool dirty)
-- | Set or clear the @device_dirty@ flag of a buffer by calling
-- @halide_buffer_t::set_device_dirty@.
setDeviceDirty :: Bool -> Ptr RawHalideBuffer -> IO ()
setDeviceDirty flag p =
  [CU.exp| void { $(halide_buffer_t* p)->set_device_dirty($(bool b)) } |]
  where
    -- 'False' becomes 0 and 'True' becomes 1 on the C side.
    b = fromBool flag
-- | Query the @host_dirty@ flag of a buffer by calling
-- @halide_buffer_t::host_dirty()@.
isHostDirty :: Ptr RawHalideBuffer -> IO Bool
isHostDirty p = do
  dirty <- [CU.exp| bool { $(const halide_buffer_t* p)->host_dirty() } |]
  pure (toBool dirty)
-- | Set or clear the @host_dirty@ flag of a buffer by calling
-- @halide_buffer_t::set_host_dirty@.
setHostDirty :: Bool -> Ptr RawHalideBuffer -> IO ()
setHostDirty flag p =
  [CU.exp| void { $(halide_buffer_t* p)->set_host_dirty($(bool b)) } |]
  where
    -- 'False' becomes 0 and 'True' becomes 1 on the C side.
    b = fromBool flag
-- | Copy the contents of a buffer from the device to the host.
--
-- Nothing happens unless the buffer's @device_dirty@ flag is set. When it is
-- set, the function calls @copy_to_host@ from the buffer's device interface.
--
-- Calls 'error' when
--
--   * @device_dirty@ is set, but the buffer has no device interface;
--   * the host pointer is @NULL@ (there is nowhere to copy to);
--   * @copy_to_host@ reports a non-zero status code;
--   * @device_dirty@ is still set after the copy.
bufferCopyToHost :: HasCallStack => Ptr RawHalideBuffer -> IO ()
bufferCopyToHost p = whenM (isDeviceDirty p) $ do
  raw <- peek p
  when (raw.halideBufferDeviceInterface == nullPtr) . error $
    "device_dirty is set, but device_interface is NULL"
  when (raw.halideBufferHost == nullPtr) . error $
    "host is NULL, did you forget to allocate memory?"
  -- copy_to_host returns an int status; propagate it instead of dropping it so
  -- that a failed copy is reported with the actual error code.
  status <-
    [CU.block| int {
      auto& buf = *$(halide_buffer_t* p);
      return buf.device_interface->copy_to_host(nullptr, &buf);
    } |]
  when (status /= 0) . error $
    "copy_to_host failed with error code " <> show status
  whenM (isDeviceDirty p) . error $
    "device_dirty is set right after a copy_to_host; something went wrong..."
-- | Verify that the runtime @dimensions@ field of a buffer agrees with the
-- type-level rank @n@. Calls 'error' with both numbers in the message when
-- they differ.
checkNumberOfDimensions :: forall n. (HasCallStack, KnownNat n) => RawHalideBuffer -> IO ()
checkNumberOfDimensions raw =
  when (fromIntegral expected /= actual) . error $
    "type-level and runtime number of dimensions do not match: "
      <> show expected
      <> " != "
      <> show actual
  where
    -- rank promised by the type
    expected = natVal (Proxy @n)
    -- rank stored in the buffer itself
    actual = raw.halideBufferDimensions
-- | Run an action on a cropped view of a buffer.
--
-- A temporary @halide_buffer_t@ is created that aliases the memory of the
-- source buffer, but has the @min@ and @extent@ of one dimension replaced.
-- If the source has a host pointer, the host pointer of the view is shifted
-- so that it points at the new minimum coordinate. If the source has a device
-- allocation and a device interface, a device crop is created with
-- @device_crop@ and released with @device_release_crop@ once the action is
-- done (also when the action throws).
--
-- The view and its dimension array are stack-allocated, so the pointer passed
-- to the action must not be used after the action returns.
--
-- Calls 'error' when the dimension index is outside @[0, rank)@ or when
-- @device_crop@ reports a non-zero status code.
withCropped
  :: Ptr (HalideBuffer n a)
  -- ^ buffer to crop
  -> Int
  -- ^ index of the dimension to crop
  -> Int
  -- ^ new minimum coordinate along that dimension
  -> Int
  -- ^ new extent along that dimension
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ action to run on the cropped view
  -> IO b
withCropped
  (castPtr -> src)
  (fromIntegral -> d)
  (fromIntegral -> min)
  (fromIntegral -> extent)
  action = do
    rank <- [CU.exp| int { $(const halide_buffer_t* src)->dimensions } |]
    -- The C++ code below indexes dim[d] without any checks, so reject invalid
    -- indices here.
    when (d < 0 || d >= rank) . error $
      "withCropped: dimension index "
        <> show d
        <> " is out of bounds for a buffer with "
        <> show rank
        <> " dimensions"
    alloca $ \dst ->
      allocaArray (fromIntegral rank) $ \dstDim ->
        bracket_ (crop dst dstDim) (releaseCrop dst) (action (castPtr dst))
  where
    -- Fill in the view: copy the source, give it its own dimension array,
    -- adjust the host pointer and the cropped dimension, and create the
    -- device crop if the source lives on a device.
    crop dst dstDim = do
      status <-
        [CU.block| int {
          auto const& src = *$(const halide_buffer_t* src);
          auto& dst = *$(halide_buffer_t* dst);
          auto const d = $(int d);
          dst = src;
          dst.dim = $(halide_dimension_t* dstDim);
          memcpy(dst.dim, src.dim, src.dimensions * sizeof(halide_dimension_t));
          if (dst.host != nullptr) {
            // widen before multiplying to avoid 32-bit overflow
            auto const shift = static_cast<int64_t>($(int min)) - src.dim[d].min;
            dst.host += shift * src.dim[d].stride * ((src.type.bits + 7) / 8);
          }
          dst.dim[d].min = $(int min);
          dst.dim[d].extent = $(int extent);
          if (src.device != 0 && src.device_interface != nullptr) {
            // The destination of a device crop must not already own a device
            // handle; dst.device was copied from src above, so reset it.
            dst.device = 0;
            return src.device_interface->device_crop(nullptr, &src, &dst);
          }
          return 0;
        } |]
      when (status /= 0) . error $
        "withCropped: device_crop failed with error code " <> show status
    -- Release the device crop (if one was created). The status code of the
    -- release is ignored because this runs as a cleanup handler.
    releaseCrop dst =
      [CU.block| void {
        auto& dst = *$(halide_buffer_t* dst);
        if (dst.device != 0 && dst.device_interface != nullptr) {
          dst.device_interface->device_release_crop(nullptr, &dst);
        }
      } |]
-- | Read the extent of the buffer along dimension @d@.
--
-- The index is validated against the type-level rank @n@ (the runtime
-- @dimensions@ field is not consulted). Calls 'error' when the index is
-- negative or not smaller than @n@.
getBufferExtent :: forall n a. KnownNat n => Ptr (HalideBuffer n a) -> Int -> IO Int
getBufferExtent (castPtr -> buf) (fromIntegral -> d)
  -- Both bounds must be checked: a negative index would read memory in front
  -- of the dimension array.
  | d >= 0 && d < fromIntegral (natVal (Proxy @n)) =
      fromIntegral <$> [CU.exp| int { $(const halide_buffer_t* buf)->dim[$(int d)].extent } |]
  | otherwise = error "index out of bounds"
-- | Read the single element stored in a zero-dimensional buffer.
--
-- The read happens inside 'withCopiedToHost'. Calls 'error' when the runtime
-- number of dimensions is not zero or when the host pointer is @NULL@.
peekScalar :: forall a. (HasCallStack, IsHalideType a) => Ptr (HalideBuffer 0 a) -> IO a
peekScalar p = withCopiedToHost p $ do
  raw <- peek (castPtr @_ @RawHalideBuffer p)
  checkNumberOfDimensions @0 raw
  let host = raw.halideBufferHost
  if host == nullPtr
    then error "host is NULL"
    else peek (castPtr @_ @a host)
-- | @NestedList n e@ is the element type @e@ wrapped in @n@ layers of lists.
--
-- Only nesting depths from 0 to 5 have an equation; for any other @n@ the
-- family does not reduce.
type family NestedList (n :: Nat) (a :: Type) where
  NestedList 0 e = e
  NestedList 1 e = [e]
  NestedList 2 e = [[e]]
  NestedList 3 e = [[[e]]]
  NestedList 4 e = [[[[e]]]]
  NestedList 5 e = [[[[[e]]]]]
-- | Count how many list constructors wrap a type: every list layer adds one,
-- and anything that is not a list counts as zero.
type family NestedListLevel (a :: Type) :: Nat where
  NestedListLevel [inner] = 1 + NestedListLevel inner
  NestedListLevel other = 0
-- | Strip all list constructors from a type and return what is left, i.e. the
-- element type of an arbitrarily nested list.
type family NestedListType (a :: Type) :: Type where
  NestedListType [inner] = NestedListType inner
  NestedListType other = other
-- | Buffers whose contents can be read into a nested Haskell list.
--
-- The three class parameters are tied together by the superclass equalities
-- and the functional dependencies: @b@ is @a@ wrapped in @n@ list layers, and
-- any two of @n@, @a@, @b@ determine the third.
class
  ( KnownNat n
  , IsHalideType a
  , NestedList n a ~ b
  , NestedListLevel b ~ n
  , NestedListType b ~ a
  ) =>
  IsListPeek n a b
    | n a -> b
    , n b -> a
    , a b -> n
  where
  -- | Read the contents of an @n@-dimensional buffer as an @n@-times nested
  -- list.
  peekToList :: HasCallStack => Ptr (HalideBuffer n a) -> IO b
-- | One-dimensional buffers are read into a flat list.
instance
  (IsHalideType a, NestedListLevel [a] ~ 1, NestedListType [a] ~ a)
  => IsListPeek 1 a [a]
  where
  -- | Read all elements of the buffer, in order of increasing coordinate.
  --
  -- The read happens inside 'withCopiedToHost'. Calls 'error' when the
  -- runtime number of dimensions is not 1 or when the host pointer is @NULL@.
  peekToList :: HasCallStack => Ptr (HalideBuffer 1 a) -> IO [a]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- make sure that dim[0] actually exists before reading it
    checkNumberOfDimensions @1 raw
    (HalideDimension _ extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
        -- The host pointer already refers to the element with the minimum
        -- coordinate, so offsets are relative to it and the dimension's min
        -- must not be added. The product is computed in 'Int' to avoid
        -- 32-bit overflow.
        step :: Int32 -> Int32 -> Int
        step stride i = fromIntegral stride * fromIntegral i
    when (ptr0 == nullPtr) . error $ "host is NULL"
    forM [0 .. extent0 - 1] $ \i0 ->
      peekElemOff ptr0 (step stride0 i0)
-- | Two-dimensional buffers are read into a list of lists. The outer list
-- runs over dimension 0 and the inner lists over dimension 1.
instance
  (IsHalideType a, NestedListLevel [[a]] ~ 2, NestedListType [[a]] ~ a)
  => IsListPeek 2 a [[a]]
  where
  -- | Read all elements of the buffer.
  --
  -- The read happens inside 'withCopiedToHost'. Calls 'error' when the
  -- runtime number of dimensions is not 2 or when the host pointer is @NULL@.
  peekToList :: HasCallStack => Ptr (HalideBuffer 2 a) -> IO [[a]]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- make sure that dim[0] and dim[1] actually exist before reading them
    checkNumberOfDimensions @2 raw
    (HalideDimension _ extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
    (HalideDimension _ extent1 stride1 _) <- peekElemOff (halideBufferDim raw) 1
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
        -- The host pointer already refers to the element with the minimum
        -- coordinates, so offsets are relative to it and the dimensions' mins
        -- must not be added. The product is computed in 'Int' to avoid
        -- 32-bit overflow.
        step :: Int32 -> Int32 -> Int
        step stride i = fromIntegral stride * fromIntegral i
    when (ptr0 == nullPtr) . error $ "host is NULL"
    forM [0 .. extent0 - 1] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` step stride0 i0
      forM [0 .. extent1 - 1] $ \i1 ->
        peekElemOff ptr1 (step stride1 i1)
-- | Three-dimensional buffers are read into a triply nested list. The
-- outermost list runs over dimension 0 and the innermost over dimension 2.
instance
  (IsHalideType a, NestedListLevel [[[a]]] ~ 3, NestedListType [[[a]]] ~ a)
  => IsListPeek 3 a [[[a]]]
  where
  -- | Read all elements of the buffer.
  --
  -- The read happens inside 'withCopiedToHost'. Calls 'error' when the
  -- runtime number of dimensions is not 3 or when the host pointer is @NULL@.
  peekToList :: HasCallStack => Ptr (HalideBuffer 3 a) -> IO [[[a]]]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- make sure that dim[0] .. dim[2] actually exist before reading them
    checkNumberOfDimensions @3 raw
    (HalideDimension _ extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
    (HalideDimension _ extent1 stride1 _) <- peekElemOff (halideBufferDim raw) 1
    (HalideDimension _ extent2 stride2 _) <- peekElemOff (halideBufferDim raw) 2
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
        -- The host pointer already refers to the element with the minimum
        -- coordinates, so offsets are relative to it and the dimensions' mins
        -- must not be added. The product is computed in 'Int' to avoid
        -- 32-bit overflow.
        step :: Int32 -> Int32 -> Int
        step stride i = fromIntegral stride * fromIntegral i
    when (ptr0 == nullPtr) . error $ "host is NULL"
    forM [0 .. extent0 - 1] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` step stride0 i0
      forM [0 .. extent1 - 1] $ \i1 -> do
        let ptr2 = ptr1 `advancePtr` step stride1 i1
        forM [0 .. extent2 - 1] $ \i2 ->
          peekElemOff ptr2 (step stride2 i2)
-- | Four-dimensional buffers are read into a four times nested list. The
-- outermost list runs over dimension 0 and the innermost over dimension 3.
instance
  (IsHalideType a, NestedListLevel [[[[a]]]] ~ 4, NestedListType [[[[a]]]] ~ a)
  => IsListPeek 4 a [[[[a]]]]
  where
  -- | Read all elements of the buffer.
  --
  -- The read happens inside 'withCopiedToHost'. Calls 'error' when the
  -- runtime number of dimensions is not 4 or when the host pointer is @NULL@.
  peekToList :: HasCallStack => Ptr (HalideBuffer 4 a) -> IO [[[[a]]]]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- make sure that dim[0] .. dim[3] actually exist before reading them
    checkNumberOfDimensions @4 raw
    (HalideDimension _ extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
    (HalideDimension _ extent1 stride1 _) <- peekElemOff (halideBufferDim raw) 1
    (HalideDimension _ extent2 stride2 _) <- peekElemOff (halideBufferDim raw) 2
    (HalideDimension _ extent3 stride3 _) <- peekElemOff (halideBufferDim raw) 3
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
        -- The host pointer already refers to the element with the minimum
        -- coordinates, so offsets are relative to it and the dimensions' mins
        -- must not be added. The product is computed in 'Int' to avoid
        -- 32-bit overflow.
        step :: Int32 -> Int32 -> Int
        step stride i = fromIntegral stride * fromIntegral i
    when (ptr0 == nullPtr) . error $ "host is NULL"
    forM [0 .. extent0 - 1] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` step stride0 i0
      forM [0 .. extent1 - 1] $ \i1 -> do
        let ptr2 = ptr1 `advancePtr` step stride1 i1
        forM [0 .. extent2 - 1] $ \i2 -> do
          let ptr3 = ptr2 `advancePtr` step stride2 i2
          forM [0 .. extent3 - 1] $ \i3 ->
            peekElemOff ptr3 (step stride3 i3)
-- | Run an action with the contents of the buffer available on the host.
--
-- Whether the buffer has a device allocation is decided once, from the state
-- of the buffer before anything else happens. If its device handle is
-- non-zero, then host memory is allocated with @allocateHostMemory@, the
-- buffer is marked as device-dirty and copied to the host with
-- 'bufferCopyToHost', the action is run, and finally the host memory is
-- released with @freeHostMemory@ (also when the action throws). If the device
-- handle is zero, the action is simply run.
--
-- NOTE(review): host memory is allocated whenever the device handle is
-- non-zero, regardless of whether the buffer already has a host pointer --
-- confirm that @allocateHostMemory@ copes with that.
withCopiedToHost :: Ptr (HalideBuffer n a) -> IO b -> IO b
withCopiedToHost (castPtr @_ @RawHalideBuffer -> buf) action = do
  raw <- peek buf
  let onDevice = raw.halideBufferDevice /= 0
      ifOnDevice io = when onDevice io
  bracket_
    (ifOnDevice (allocateHostMemory buf))
    (ifOnDevice (freeHostMemory buf))
    (ifOnDevice (setDeviceDirty True buf >> bufferCopyToHost buf) >> action)