{-# LANGUAGE TypeFamilies #-}
module Data.Array.Accelerate.IO.Foreign.Ptr
where
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Array.Unique
import Foreign.Ptr
import Foreign.ForeignPtr
import System.IO.Unsafe
-- | The raw pointers backing an array whose element representation is @a@.
-- This is simply a synonym for 'ArrayPtrs'. For a primitive component it is a
-- single 'Ptr'; for a pair representation it is a tuple of the components'
-- pointer structures.
type Ptrs a = ArrayPtrs a
-- | Wrap a set of raw pointers as an Accelerate array of the given shape.
--
-- This function does not copy any element data itself. Each leaf pointer is
-- turned into a 'ForeignPtr' with 'newForeignPtr_', which attaches no
-- finalizer, and is then handed to 'newUniqueArray'. As a result the array
-- never frees the memory: the caller keeps ownership and must keep the
-- buffers alive for as long as the array is used.
--
-- NOTE(review): nothing here checks that each buffer holds at least as many
-- elements as the shape demands. That is assumed to be the caller's
-- obligation.
{-# INLINE fromPtrs #-}
fromPtrs :: (Shape sh, Elt e) => sh -> Ptrs (EltRepr e) -> Array sh e
fromPtrs sh ps = Array (fromElt sh) (build arrayElt ps)
  where
    -- Lift one raw buffer into a 'UniqueArray' and apply the matching
    -- 'ArrayData' constructor to it.
    --
    -- 'unsafePerformIO' keeps 'fromPtrs' pure. The only IO performed is
    -- 'newForeignPtr_' (no finalizer) followed by 'newUniqueArray'.
    -- NOTE(review): this assumes 'newUniqueArray' has no observable effect
    -- beyond allocating a wrapper, so evaluating it lazily is harmless.
    -- Confirm against its definition.
    adopt :: (UniqueArray a -> r) -> Ptr a -> r
    adopt con ptr = con (unsafePerformIO (newForeignPtr_ ptr >>= newUniqueArray))

    -- Walk the element representation, pairing each primitive leaf with its
    -- pointer. Vector representations share their component's pointer, and
    -- pair representations split the pointer tuple.
    build :: ArrayEltR t -> Ptrs t -> ArrayData t
    build eltR ptrs =
      case eltR of
        ArrayEltRunit       -> AD_Unit
        ArrayEltRint        -> adopt AD_Int ptrs
        ArrayEltRint8       -> adopt AD_Int8 ptrs
        ArrayEltRint16      -> adopt AD_Int16 ptrs
        ArrayEltRint32      -> adopt AD_Int32 ptrs
        ArrayEltRint64      -> adopt AD_Int64 ptrs
        ArrayEltRword       -> adopt AD_Word ptrs
        ArrayEltRword8      -> adopt AD_Word8 ptrs
        ArrayEltRword16     -> adopt AD_Word16 ptrs
        ArrayEltRword32     -> adopt AD_Word32 ptrs
        ArrayEltRword64     -> adopt AD_Word64 ptrs
        ArrayEltRcshort     -> adopt AD_CShort ptrs
        ArrayEltRcushort    -> adopt AD_CUShort ptrs
        ArrayEltRcint       -> adopt AD_CInt ptrs
        ArrayEltRcuint      -> adopt AD_CUInt ptrs
        ArrayEltRclong      -> adopt AD_CLong ptrs
        ArrayEltRculong     -> adopt AD_CULong ptrs
        ArrayEltRcllong     -> adopt AD_CLLong ptrs
        ArrayEltRcullong    -> adopt AD_CULLong ptrs
        ArrayEltRhalf       -> adopt AD_Half ptrs
        ArrayEltRfloat      -> adopt AD_Float ptrs
        ArrayEltRdouble     -> adopt AD_Double ptrs
        ArrayEltRcfloat     -> adopt AD_CFloat ptrs
        ArrayEltRcdouble    -> adopt AD_CDouble ptrs
        ArrayEltRbool       -> adopt AD_Bool ptrs
        ArrayEltRchar       -> adopt AD_Char ptrs
        ArrayEltRcchar      -> adopt AD_CChar ptrs
        ArrayEltRcschar     -> adopt AD_CSChar ptrs
        ArrayEltRcuchar     -> adopt AD_CUChar ptrs
        ArrayEltRvec2 inner  -> AD_V2 (build inner ptrs)
        ArrayEltRvec3 inner  -> AD_V3 (build inner ptrs)
        ArrayEltRvec4 inner  -> AD_V4 (build inner ptrs)
        ArrayEltRvec8 inner  -> AD_V8 (build inner ptrs)
        ArrayEltRvec16 inner -> AD_V16 (build inner ptrs)
        ArrayEltRpair lhs rhs ->
          -- A strict match on the tuple, as in a lambda pattern.
          case ptrs of
            (lp, rp) -> AD_Pair (build lhs lp) (build rhs rp)
-- | Expose the raw pointers backing an Accelerate array. The array's shape is
-- discarded; only the element payload's pointers are returned, via
-- 'ptrsOfArrayData'.
--
-- NOTE(review): the returned pointers are presumably only valid while the
-- array's underlying storage stays alive. Confirm how 'ptrsOfArrayData'
-- obtains them before relying on them after the array may be collected.
{-# INLINE toPtrs #-}
toPtrs :: (Shape sh, Elt e) => Array sh e -> Ptrs (EltRepr e)
toPtrs arr =
  case arr of
    Array _shape payload -> ptrsOfArrayData payload