{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Language.Halide.ArrayFire () where

import ArrayFire (AFType, Array)
import ArrayFire qualified as AF
import Data.Proxy
import GHC.TypeLits
import Language.Halide

-- | Interpret the dimensions of an ArrayFire array as an @n@-dimensional shape.
--
-- 'AF.getDims' reports four extents @(d0, d1, d2, d3)@. The array may be
-- viewed as an @n@-dimensional buffer exactly when every extent beyond the
-- first @n@ equals 1. In that case the first @n@ extents are returned, padded
-- with 1s when @n > 4@, so the result always has length @n@. Otherwise, and
-- for negative @n@, this calls 'error'.
forceNumDims :: AFType a => Array a -> Int -> [Int]
forceNumDims arr n
  -- Only extents equal to 1 can be dropped without changing the data being
  -- described. A zero extent (an empty array) or a larger one must be kept.
  | n >= 0, all (== 1) (drop n dims) = take n (dims ++ repeat 1)
  | otherwise =
      error $
        "cannot treat a "
          <> show (AF.getNumDims arr)
          <> "-dimensional array as a "
          <> show n
          <> "-dimensional buffer"
  where
    -- All four extents reported by ArrayFire, in the order 'AF.getDims' returns them.
    dims :: [Int]
    dims = let (d0, d1, d2, d3) = AF.getDims arr in [d0, d1, d2, d3]

-- | Expose an ArrayFire 'Array' to Halide as an @n@-dimensional buffer.
--
-- Only arrays on the CPU backend are supported. Their device pointer is handed
-- to 'bufferFromPtrShape' together with the shape computed by 'forceNumDims'
-- for the type-level rank @n@. Arrays on any other backend raise an 'error'.
--
-- NOTE(review): assumes the CPU device pointer addresses data laid out densely
-- according to that shape. Confirm this for ArrayFire sub-array views.
instance (IsHalideType a, AFType a, KnownNat n, n <= 4) => IsHalideBuffer (Array a) n a where
  withHalideBufferImpl arr action = case AF.getBackend arr of
    AF.CPU ->
      -- Force the shape first, so that a rank mismatch is reported before
      -- the array's device pointer is acquired.
      shape `seq` AF.withDevicePtr arr (\ptr -> bufferFromPtrShape ptr shape action)
    AF.CUDA -> error "do not know how to handle 'CUDA' backend yet"
    AF.OpenCL -> error "do not know how to handle 'OpenCL' backend yet"
    AF.Default -> error "do not know how to handle 'Default' backend"
    where
      -- Extents of the array viewed with the rank fixed by the instance's @n@.
      shape :: [Int]
      shape = forceNumDims arr (fromIntegral (natVal (Proxy @n)))