{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      : Language.Halide.ArrayFire
-- Description : Use ArrayFire arrays as Halide buffers
--
-- This module exports no bindings. Importing it brings into scope an
-- 'IsHalideBuffer' instance for ArrayFire's 'Array' type.
module Language.Halide.ArrayFire () where

import ArrayFire (AFType, Array)
import ArrayFire qualified as AF
import Data.Proxy
import GHC.TypeLits
import Language.Halide

-- | Extents of an ArrayFire array, truncated to the first @n@ dimensions.
--
-- 'AF.getDims' always reports four extents. This keeps the first @n@ of them,
-- so at most four extents are ever returned.
--
-- Fails with 'error' when 'AF.getNumDims' reports more than @n@ dimensions,
-- i.e. when the array cannot be viewed with the requested rank.
forceNumDims :: AFType a => Array a -> Int -> [Int]
forceNumDims arr n
  | numDims <= n = take n [d0, d1, d2, d3]
  | otherwise =
      error $
        "cannot treat a "
          <> show numDims
          <> "-dimensional array as a "
          <> show n
          <> "-dimensional buffer"
  where
    numDims = AF.getNumDims arr
    (d0, d1, d2, d3) = AF.getDims arr

-- | View an ArrayFire 'Array' as an @n@-dimensional Halide buffer (@n <= 4@).
--
-- Only arrays on the CPU backend are handled. For those, the pointer obtained
-- from 'AF.withDevicePtr' is passed straight to 'bufferFromPtrShape'.
--
-- For the CUDA and OpenCL backends that pointer presumably refers to device
-- memory (NOTE(review): confirm). This instance does not wrap device memory,
-- so those backends fail with a descriptive error. The 'AF.Default' backend
-- is rejected as well.
instance
  (IsHalideType a, AFType a, KnownNat n, n <= 4) =>
  IsHalideBuffer (Array a) n a
  where
  withHalideBufferImpl arr action = case AF.getBackend arr of
    AF.CPU -> AF.withDevicePtr arr $ \ptr -> bufferFromPtrShape ptr shape action
    AF.CUDA -> unsupported "CUDA"
    AF.OpenCL -> unsupported "OpenCL"
    AF.Default -> error "do not know how to handle 'Default' backend"
    where
      -- The rank requested at the type level decides how many extents we keep.
      shape = forceNumDims arr (fromIntegral (natVal (Proxy @n)))
      -- Descriptive failure instead of a bare 'undefined' for backends whose
      -- memory this instance cannot hand to Halide yet.
      unsupported backend =
        error $
          "Language.Halide.ArrayFire: arrays on the "
            <> backend
            <> " backend are not supported yet; only CPU arrays can be"
            <> " converted to Halide buffers"