{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Array.Accelerate.IO.Data.Vector.Primitive (
Vectors,
toVectors,
fromVectors,
) where
import Data.Vector.Primitive
import Data.Array.Accelerate.IO.Data.Vector.Primitive.Internal
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Sugar hiding ( Vector )
import Data.Array.Accelerate.Array.Unique
import Data.Array.Accelerate.Error
import qualified Data.Array.Accelerate.Array.Representation as R
import Data.Int
import Data.Word
-- | The primitive-vector counterpart of an element representation type.
--
-- Each supported scalar type maps to a 'Vector' (the "Data.Vector.Primitive"
-- one; the Accelerate 'Vector' is hidden on import) of that same type, unit
-- maps to unit, and a pair maps to the pair of its components' mappings, so
-- nested tuples of scalars become nested tuples of vectors.
--
-- The family is open; the instances below list the supported leaf types.
-- NOTE(review): scalar representation types without an instance here (e.g.
-- 'Bool') have no 'Vectors' image and so cannot go through 'fromVectors' /
-- 'toVectors' — confirm this is intended.
type family Vectors e :: *
-- Unit carries no data.
type instance Vectors () = ()
-- Signed integers.
type instance Vectors Int = Vector Int
type instance Vectors Int8 = Vector Int8
type instance Vectors Int16 = Vector Int16
type instance Vectors Int32 = Vector Int32
type instance Vectors Int64 = Vector Int64
-- Unsigned integers.
type instance Vectors Word = Vector Word
type instance Vectors Word8 = Vector Word8
type instance Vectors Word16 = Vector Word16
type instance Vectors Word32 = Vector Word32
type instance Vectors Word64 = Vector Word64
-- Floating point and characters.
type instance Vectors Float = Vector Float
type instance Vectors Double = Vector Double
type instance Vectors Char = Vector Char
-- Pairs are mapped component-wise.
type instance Vectors (a,b) = (Vectors a, Vectors b)
-- | Assemble an Accelerate 'Array' of shape @sh@ from primitive 'Vector's.
--
-- The vector argument mirrors the element's representation type
-- ('EltRepr'), as described by 'Vectors': one 'Vector' per scalar leaf,
-- nested in pairs. Each leaf vector is turned into array storage with
-- 'uniqueArrayOfVector'. Its length is compared with the number of
-- elements of @sh@ through '$boundsCheck', which reports "shape mismatch".
-- NOTE(review): presumably that check is only active when Accelerate's
-- bounds checks are enabled — confirm.
--
-- Representation constructors not handled below hit the internal
-- "unsupported type" error.
{-# INLINE fromVectors #-}
fromVectors :: (Shape sh, Elt e) => sh -> Vectors (EltRepr e) -> Array sh e
fromVectors sh vecs = Array (fromElt sh) (go arrayElt vecs)
  where
    -- Element count implied by the requested shape; every leaf vector's
    -- length is compared against it.
    expected :: Int
    expected = size sh

    -- Convert one leaf vector into array storage, after the length check.
    {-# INLINE adopt #-}
    adopt :: Prim a => Vector a -> UniqueArray a
    adopt v@(Vector _ len _) =
      $boundsCheck "fromVectors" "shape mismatch" (len == expected) (uniqueArrayOfVector v)

    -- Walk the representation, pairing each scalar leaf with its vector.
    {-# INLINE go #-}
    go :: ArrayEltR t -> Vectors t -> ArrayData t
    go rep v =
      case rep of
        ArrayEltRunit       -> AD_Unit
        ArrayEltRint        -> AD_Int    (adopt v)
        ArrayEltRint8       -> AD_Int8   (adopt v)
        ArrayEltRint16      -> AD_Int16  (adopt v)
        ArrayEltRint32      -> AD_Int32  (adopt v)
        ArrayEltRint64      -> AD_Int64  (adopt v)
        ArrayEltRword       -> AD_Word   (adopt v)
        ArrayEltRword8      -> AD_Word8  (adopt v)
        ArrayEltRword16     -> AD_Word16 (adopt v)
        ArrayEltRword32     -> AD_Word32 (adopt v)
        ArrayEltRword64     -> AD_Word64 (adopt v)
        ArrayEltRchar       -> AD_Char   (adopt v)
        ArrayEltRfloat      -> AD_Float  (adopt v)
        ArrayEltRdouble     -> AD_Double (adopt v)
        ArrayEltRpair rA rB ->
          -- Forcing the tuple here matches a strict (v1,v2) pattern.
          case v of
            (vA, vB) -> AD_Pair (go rA vA) (go rB vB)
        _                   -> $internalError "fromVectors" "unsupported type"
-- | Convert an Accelerate 'Array' into primitive 'Vector's.
--
-- The result mirrors the element's representation type ('EltRepr'), as
-- described by 'Vectors': one 'Vector' per scalar leaf, nested in pairs.
-- Each leaf's storage is converted with 'vectorOfUniqueArray', given the
-- array's element count (@R.size@ of its shape).
--
-- Representation constructors not handled below hit the internal
-- "unsupported type" error.
{-# INLINE toVectors #-}
toVectors :: (Shape sh, Elt e) => Array sh e -> Vectors (EltRepr e)
toVectors (Array sh adata) = go arrayElt adata
  where
    -- Number of elements in the array. The bang forces it as soon as the
    -- result is demanded, before the representation is walked.
    len :: Int
    !len = R.size sh

    -- Convert one leaf's storage into a vector of 'len' elements.
    {-# INLINE adopt #-}
    adopt :: Prim a => UniqueArray a -> Vector a
    adopt = vectorOfUniqueArray len

    -- Walk the representation, converting each scalar leaf's storage.
    {-# INLINE go #-}
    go :: ArrayEltR t -> ArrayData t -> Vectors t
    go rep payload =
      case rep of
        ArrayEltRunit       -> case payload of AD_Unit     -> ()
        ArrayEltRint        -> case payload of AD_Int    a -> adopt a
        ArrayEltRint8       -> case payload of AD_Int8   a -> adopt a
        ArrayEltRint16      -> case payload of AD_Int16  a -> adopt a
        ArrayEltRint32      -> case payload of AD_Int32  a -> adopt a
        ArrayEltRint64      -> case payload of AD_Int64  a -> adopt a
        ArrayEltRword       -> case payload of AD_Word   a -> adopt a
        ArrayEltRword8      -> case payload of AD_Word8  a -> adopt a
        ArrayEltRword16     -> case payload of AD_Word16 a -> adopt a
        ArrayEltRword32     -> case payload of AD_Word32 a -> adopt a
        ArrayEltRword64     -> case payload of AD_Word64 a -> adopt a
        ArrayEltRchar       -> case payload of AD_Char   a -> adopt a
        ArrayEltRfloat      -> case payload of AD_Float  a -> adopt a
        ArrayEltRdouble     -> case payload of AD_Double a -> adopt a
        ArrayEltRpair rA rB ->
          case payload of
            AD_Pair pA pB -> (go rA pA, go rB pB)
        _                   -> $internalError "toVectors" "unsupported type"