--------------------------------------------------------------------------------
-- |
-- Module      : ArrayFire.Index
-- Copyright   : David Johnson (c) 2019-2020
-- License     : BSD 3
-- Maintainer  : David Johnson <djohnson.m@gmail.com>
-- Stability   : Experimental
-- Portability : GHC
--
-- Functions for indexing into an 'Array'
--
--------------------------------------------------------------------------------
module ArrayFire.Index where

import ArrayFire.Internal.Index
import ArrayFire.Internal.Types
import ArrayFire.FFI
import ArrayFire.Exception

import Foreign

import System.IO.Unsafe
import Control.Exception

-- | Index into an 'Array' by a list of 'Seq' (binding for ArrayFire's
-- @af_index@).
--
-- The number of dimensions handed to @af_index@ is the length of the
-- 'Seq' list. Any error code returned by @af_index@ is raised through
-- 'throwAFError'. The resulting 'Array' owns its handle, which is released
-- by 'af_release_array_finalizer'.
index
  :: Array a
  -- ^ 'Array' argument
  -> [Seq]
  -- ^ 'Seq' to use for indexing
  -> Array a
index (Array fptr) seqs =
  -- 'unsafePerformIO' exposes indexing as a pure function. This assumes
  -- @af_index@ has no observable side effects beyond allocating the fresh
  -- output array.
  -- 'mask_' masks asynchronous exceptions for the whole FFI sequence, so one
  -- cannot arrive after @af_index@ succeeds but before the output handle is
  -- wrapped in a finalized 'ForeignPtr'.
  unsafePerformIO . mask_ . withForeignPtr fptr $ \ptr ->
    alloca $ \aptr ->
      withArray (toAFSeq <$> seqs) $ \sptr -> do
        -- @af_index@ writes the new array handle into @aptr@.
        throwAFError =<< af_index aptr ptr n sptr
        Array <$> (newForeignPtr af_release_array_finalizer =<< peek aptr)
  where
    -- Number of dimensions being indexed, i.e. how many 'Seq's were given.
    n :: CUInt
    n = fromIntegral (length seqs)

-- | Lookup an Array by keys along a specified dimension (binding for
-- ArrayFire's @af_lookup@).
--
-- The dimension is converted to 'CUInt' before being passed to @af_lookup@.
-- NOTE(review): assumes 'op2' hands the two array handles to the callback in
-- argument order (first 'Array' as @x@, second as @y@), matching
-- @af_lookup(out, in, indices, dim)@ — confirm against 'op2'.
lookup
  :: Array a
  -- ^ Input 'Array' to read values from
  -> Array a
  -- ^ 'Array' of keys (indices) to look up
  -> Int
  -- ^ Dimension along which to look up
  -> Array a
lookup a b n = op2 a b $ \p x y -> af_lookup p x y (fromIntegral n)

-- af_err af_assign_seq( af_array *out, const af_array lhs, const unsigned ndims, const af_seq* const indices, const af_array rhs);
-- | Assign the values of one 'Array' into the region of another 'Array'
-- selected by a list of 'Seq' (intended binding for @af_assign_seq@).
--
-- Not yet implemented.
-- assignSeq :: Array a -> Int -> [Seq] -> Array a -> Array a
-- assignSeq = error "Not implemented"

-- af_err af_index_gen(  af_array *out, const af_array in, const dim_t ndims, const af_index_t* indices);
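-- | Index into an 'Array' using general indexers, each of which may be a
-- 'Seq' or an 'Array' (intended binding for @af_index_gen@).
--
-- Not yet implemented.
-- NOTE(review): @af_index_gen@ takes no right-hand-side array, so the
-- trailing 'Array' argument in the draft signature below looks copied from
-- @assignGen@; confirm before implementing.
-- indexGen :: Array a -> Int -> [Index a] -> Array a -> Array a
-- indexGen = error "Not implemented"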

-- af_err af_assign_gen( af_array *out, const af_array lhs, const dim_t ndims, const af_index_t* indices, const af_array rhs);
-- | Assign the values of one 'Array' into the region of another 'Array'
-- selected by general indexers (intended binding for @af_assign_gen@).
--
-- Not yet implemented.
-- assignGen :: Array a -> Int -> [Index a] -> Array a -> Array a
-- assignGen = error "Not implemented"

-- af_err af_create_indexers(af_index_t** indexers);
-- af_err af_set_array_indexer(af_index_t* indexer, const af_array idx, const dim_t dim);
-- af_err af_set_seq_indexer(af_index_t* indexer, const af_seq* idx, const dim_t dim, const bool is_batch);
-- af_err af_set_seq_param_indexer(af_index_t* indexer, const double begin, const double end, const double step, const dim_t dim, const bool is_batch);
-- af_err af_release_indexers(af_index_t* indexers);