{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Base
where
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Array.Sugar ( Array(..), EltRepr )
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Unique
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Foreign.Marshal.Alloc
import Foreign.Ptr
import Foreign.Storable
import Foreign.Storable.Complex ( )
import qualified Blas.Primitive.Types as C
-- | Translate this package's 'Transpose' flag into the matching constructor
-- of the BLAS bindings' transpose type ('C.Transpose').
--
--   * 'N' becomes 'C.NoTrans'
--   * 'T' becomes 'C.Trans'
--   * 'H' becomes 'C.ConjTrans'
encodeTranspose :: Transpose -> C.Transpose
encodeTranspose op =
  case op of
    N -> C.NoTrans
    T -> C.Trans
    H -> C.ConjTrans
-- | Run an action with the raw pointer(s) backing an 'Array'.
--
-- The shape component of the array is ignored; only its payload is unpacked.
-- The element type's 'NumericR' witness is obtained from the 'Numeric'
-- constraint and passed, together with the payload, to 'withArrayData',
-- which does the actual work.
withArray
    :: forall sh e b. Numeric e
    => Array sh e
    -> (ArrayPtrs (EltRepr e) -> IO b)
    -> IO b
withArray (Array _ payload) = withArrayData witness payload
  where
    -- Needs ScopedTypeVariables so that 'e' refers to the element type above.
    witness :: NumericR e
    witness = numericR
-- | Run an action with the raw pointer(s) underlying an array payload.
--
-- Matching on the 'NumericR' witness fixes the element type, and with it the
-- layout of the payload and of the pointer structure given to the
-- continuation:
--
--   * 'NumericRfloat32' / 'NumericRfloat64': the payload is a single unique
--     array, and the continuation receives the pointer produced by
--     'withUniqueArrayPtr' directly.
--
--   * 'NumericRcomplex32' / 'NumericRcomplex64': the payload is a nested pair
--     holding one unique array for the real components and one for the
--     imaginary components. Both are opened (real first, then imaginary) and
--     the continuation receives @(((), p_re), p_im)@.
withArrayData
    :: NumericR e
    -> ArrayData (EltRepr e)
    -> (ArrayPtrs (EltRepr e) -> IO b)
    -> IO b
withArrayData NumericRfloat32 (AD_Float arr) k = withUniqueArrayPtr arr k
withArrayData NumericRfloat64 (AD_Double arr) k = withUniqueArrayPtr arr k
withArrayData NumericRcomplex32 (AD_Pair reSide imSide) k
  | AD_Pair AD_Unit (AD_Float reals) <- reSide
  , AD_Float imags                   <- imSide
  = withUniqueArrayPtr reals (\pRe ->
      withUniqueArrayPtr imags (\pIm ->
        k (((), pRe), pIm)))
withArrayData NumericRcomplex64 (AD_Pair reSide imSide) k
  | AD_Pair AD_Unit (AD_Double reals) <- reSide
  , AD_Double imags                   <- imSide
  = withUniqueArrayPtr reals (\pRe ->
      withUniqueArrayPtr imags (\pIm ->
        k (((), pRe), pIm)))
-- | Build a temporary packed complex buffer from separate real and imaginary
-- arrays, and run a continuation on it.
--
-- A scratch buffer of @n * sizeOf (Complex e)@ bytes, aligned to 16 bytes, is
-- allocated with 'allocaBytesAligned'. The precision-specific C routine
-- ('c_interleave_f32' or 'c_interleave_f64', chosen by the 'NumericR'
-- witness) is called with @0@, @n@, the scratch buffer, and the real and
-- imaginary pointers. The continuation is then applied to the scratch buffer.
--
-- The buffer comes from 'allocaBytesAligned', so the pointer must not be used
-- after the continuation returns.
interleave
    :: forall e b. (Storable e, Numeric (Complex e))
    => ArrayPtrs (EltRepr (Complex e))
    -> Int
    -> (Ptr (Complex e) -> IO b)
    -> IO b
interleave (((), p_re), p_im) n k =
  allocaBytesAligned scratchBytes 16 $ \scratch -> do
    -- Binding the result to () pins the type of the GADT case expression.
    () <- case numericR :: NumericR (Complex e) of
            NumericRcomplex32 -> c_interleave_f32 0 n scratch p_re p_im
            NumericRcomplex64 -> c_interleave_f64 0 n scratch p_re p_im
    k scratch
  where
    -- Size in bytes of n packed complex elements.
    scratchBytes :: Int
    scratchBytes = n * sizeOf (undefined :: Complex e)
-- | Hand a packed complex buffer and a pair of separate real and imaginary
-- arrays to the C de-interleaving routine.
--
-- The precision-specific routine ('c_deinterleave_f32' or
-- 'c_deinterleave_f64') is chosen by the 'NumericR' witness for
-- @Complex e@. It is called with @0@, @n@, the real pointer, the imaginary
-- pointer, and finally the packed complex pointer.
--
-- NOTE(review): presumably the packed buffer is the input and the two
-- component arrays are overwritten -- confirm against the C implementation.
deinterleave
    :: forall e. (Storable e, Numeric (Complex e))
    => ArrayPtrs (EltRepr (Complex e))
    -> Ptr (Complex e)
    -> Int
    -> IO ()
deinterleave (((), reals), imags) packed n =
  case witness of
    NumericRcomplex32 -> c_deinterleave_f32 0 n reals imags packed
    NumericRcomplex64 -> c_deinterleave_f64 0 n reals imags packed
  where
    witness :: NumericR (Complex e)
    witness = numericR
-- | Binding to the C symbol @interleave_f32@ (single-precision variant).
--
-- In this module it is invoked as @c_interleave_f32 0 n p_cplx p_re p_im@:
-- two 'Int's, then the packed complex buffer, then the separate real and
-- imaginary arrays.
--
-- NOTE(review): the two 'Int's look like a start/end element range, and the
-- complex pointer like the output -- confirm against the C source. 'Int' is
-- marshalled as @HsInt@, so the C prototype must use a matching integer type.
--
-- Declared @unsafe@, so the C routine must not call back into Haskell.
foreign import ccall unsafe "interleave_f32"
  c_interleave_f32
    :: Int
    -> Int
    -> Ptr (Complex Float)
    -> Ptr Float
    -> Ptr Float
    -> IO ()
-- | Binding to the C symbol @interleave_f64@ (double-precision variant).
--
-- In this module it is invoked as @c_interleave_f64 0 n p_cplx p_re p_im@:
-- two 'Int's, then the packed complex buffer, then the separate real and
-- imaginary arrays.
--
-- NOTE(review): the two 'Int's look like a start/end element range, and the
-- complex pointer like the output -- confirm against the C source. 'Int' is
-- marshalled as @HsInt@, so the C prototype must use a matching integer type.
--
-- Declared @unsafe@, so the C routine must not call back into Haskell.
foreign import ccall unsafe "interleave_f64"
  c_interleave_f64
    :: Int
    -> Int
    -> Ptr (Complex Double)
    -> Ptr Double
    -> Ptr Double
    -> IO ()
-- | Binding to the C symbol @deinterleave_f32@ (single-precision variant).
--
-- In this module it is invoked as
-- @c_deinterleave_f32 0 n p_re p_im p_cplx@: two 'Int's, then the separate
-- real and imaginary arrays, then the packed complex buffer.
--
-- NOTE(review): the two 'Int's look like a start/end element range, and the
-- two component pointers like the outputs -- confirm against the C source.
-- 'Int' is marshalled as @HsInt@, so the C prototype must use a matching
-- integer type.
--
-- Declared @unsafe@, so the C routine must not call back into Haskell.
foreign import ccall unsafe "deinterleave_f32"
  c_deinterleave_f32
    :: Int
    -> Int
    -> Ptr Float
    -> Ptr Float
    -> Ptr (Complex Float)
    -> IO ()
-- | Binding to the C symbol @deinterleave_f64@ (double-precision variant).
--
-- In this module it is invoked as
-- @c_deinterleave_f64 0 n p_re p_im p_cplx@: two 'Int's, then the separate
-- real and imaginary arrays, then the packed complex buffer.
--
-- NOTE(review): the two 'Int's look like a start/end element range, and the
-- two component pointers like the outputs -- confirm against the C source.
-- 'Int' is marshalled as @HsInt@, so the C prototype must use a matching
-- integer type.
--
-- Declared @unsafe@, so the C routine must not call back into Haskell.
foreign import ccall unsafe "deinterleave_f64"
  c_deinterleave_f64
    :: Int
    -> Int
    -> Ptr Double
    -> Ptr Double
    -> Ptr (Complex Double)
    -> IO ()