{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Data.Array.Accelerate.IO.Data.Vector.Generic.Mutable
where
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Unique
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape
import Data.Array.Accelerate.Type
import qualified Data.Vector.Generic.Mutable as V
import Control.Monad.Primitive
import Data.Typeable
import Foreign.Marshal.Utils
import Foreign.Storable
import Prelude hiding ( length )
import GHC.Base
import GHC.ForeignPtr
-- | Mutable Accelerate array usable through the generic mutable-vector API.
--
-- The constructor stores two things: the array extent in representation form
-- (@EltR sh@) and the mutable payload for the representation type of the
-- elements (@MutableArrayData (EltR e)@).
--
-- The @s@ parameter does not occur in either field; it only fills the
-- state-token slot expected by 'V.MVector'.
data MArray sh s e where
MArray :: EltR sh
-- The extent of the array, in representation form.
-> MutableArrayData (EltR e)
-- The mutable backing storage for the element representation.
-> MArray sh s e
-- 'Typeable' is derived standalone for the GADT.
deriving instance Typeable MArray
-- | One-dimensional specialisation of 'MArray'; this is the type the
-- 'V.MVector' instance in this module is declared for.
type MVector = MArray DIM1
-- | Generic mutable-vector operations for one-dimensional Accelerate arrays.
--
-- The payload of an 'MArray' is a 'MutableArrayData' for @EltR e@: unit, a
-- pair of payloads, or a single 'UniqueArray' at the leaves. Every operation
-- that needs raw memory access therefore follows the same recipe:
--
--   1. @go@ walks the 'TypeR' evidence from 'eltR' alongside the payload;
--
--   2. @scalar@ \/ @vector@ determine how many primitive values make up one
--      element of that component (@1@ for a single scalar, the width @w@ for a
--      SIMD-style vector type);
--
--   3. @single@ \/ @num@ \/ @integral@ \/ @floating@ match on the concrete
--      primitive type so that its 'Storable' instance is in scope for the
--      leaf worker.
--
instance Elt e => V.MVector MVector e where
  {-# INLINE basicLength      #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicOverlaps    #-}
  {-# INLINE basicUnsafeNew   #-}
  {-# INLINE basicInitialize  #-}
  {-# INLINE basicUnsafeRead  #-}
  {-# INLINE basicUnsafeWrite #-}
  {-# INLINE basicUnsafeCopy  #-}

  -- The length is the single extent stored in the shape component.
  basicLength (MArray ((), n) _) = n

  -- @basicUnsafeSlice j m v@: a view of @m@ elements starting at element @j@.
  -- No bounds are checked. Each component array keeps its unique id and
  -- lifetime fields; only the foreign pointer is advanced.
  basicUnsafeSlice j m (MArray _ mad) = MArray ((),m) (go (eltR @e) mad)
    where
      go :: TypeR a -> MutableArrayData a -> MutableArrayData a
      go TupRunit           ()       = ()
      go (TupRpair aR1 aR2) (a1, a2) = (go aR1 a1, go aR2 a2)
      go (TupRsingle aR)    a        = scalar aR a

      scalar :: ScalarType a -> MutableArrayData a -> MutableArrayData a
      scalar (SingleScalarType t) a = single t a 1
      scalar (VectorScalarType t) a = vector t a

      vector :: VectorType a -> MutableArrayData a -> MutableArrayData a
      vector (VectorType w t) a
        | SingleArrayDict <- singleArrayDict t
        = single t a w

      single :: SingleType a -> MutableArrayData a -> Int -> MutableArrayData a
      single (NumSingleType t) = num t

      num :: NumType a -> MutableArrayData a -> Int -> MutableArrayData a
      num (IntegralNumType t) = integral t
      num (FloatingNumType t) = floating t

      integral :: IntegralType a -> MutableArrayData a -> Int -> MutableArrayData a
      integral TypeInt    = slice
      integral TypeInt8   = slice
      integral TypeInt16  = slice
      integral TypeInt32  = slice
      integral TypeInt64  = slice
      integral TypeWord   = slice
      integral TypeWord8  = slice
      integral TypeWord16 = slice
      integral TypeWord32 = slice
      integral TypeWord64 = slice

      floating :: FloatingType a -> MutableArrayData a -> Int -> MutableArrayData a
      floating TypeHalf   = slice
      floating TypeFloat  = slice
      floating TypeDouble = slice

      -- Advance the pointer by @j@ elements, where one element occupies @s@
      -- primitive values of type @a@.
      slice :: forall a. Storable a => UniqueArray a -> Int -> UniqueArray a
      slice (UniqueArray uid (Lifetime lft w fp)) s =
        UniqueArray uid (Lifetime lft w (plusForeignPtr fp (j * s * sizeOf (undefined::a))))

  -- Two vectors overlap when any pair of corresponding component arrays
  -- occupies intersecting address ranges.
  basicOverlaps (MArray ((),m) mad1) (MArray ((),n) mad2) = go (eltR @e) mad1 mad2
    where
      go :: TypeR a -> MutableArrayData a -> MutableArrayData a -> Bool
      go TupRunit           ()       ()       = False
      go (TupRpair aR1 aR2) (a1, a2) (b1, b2) = go aR1 a1 b1 || go aR2 a2 b2
      go (TupRsingle aR)    a        b        = scalar aR a b

      scalar :: ScalarType a -> MutableArrayData a -> MutableArrayData a -> Bool
      scalar (SingleScalarType t) a b = single t a b 1
      scalar (VectorScalarType t) a b = vector t a b

      vector :: VectorType a -> MutableArrayData a -> MutableArrayData a -> Bool
      vector (VectorType w t) a b
        | SingleArrayDict <- singleArrayDict t
        = single t a b w

      single :: SingleType a -> MutableArrayData a -> MutableArrayData a -> Int -> Bool
      single (NumSingleType t) = num t

      num :: NumType a -> MutableArrayData a -> MutableArrayData a -> Int -> Bool
      num (IntegralNumType t) = integral t
      num (FloatingNumType t) = floating t

      integral :: IntegralType a -> MutableArrayData a -> MutableArrayData a -> Int -> Bool
      integral TypeInt    = overlaps
      integral TypeInt8   = overlaps
      integral TypeInt16  = overlaps
      integral TypeInt32  = overlaps
      integral TypeInt64  = overlaps
      integral TypeWord   = overlaps
      integral TypeWord8  = overlaps
      integral TypeWord16 = overlaps
      integral TypeWord32 = overlaps
      integral TypeWord64 = overlaps

      floating :: FloatingType a -> MutableArrayData a -> MutableArrayData a -> Int -> Bool
      floating TypeHalf   = overlaps
      floating TypeFloat  = overlaps
      floating TypeDouble = overlaps

      -- Overlap is decided purely on the address ranges
      -- @[i, i + m*k)@ and @[j, j + n*k)@, where @k@ is the size in bytes of
      -- one element of this component.
      --
      -- The 'ForeignPtrContents' are deliberately NOT consulted: insisting
      -- that both pointers wrap the same mutable byte array would report
      -- "no overlap" for any buffer whose contents are not byte-array backed
      -- (for example foreign-allocated memory), even when two slices of it
      -- alias. Comparing the addresses is sufficient, because distinct live
      -- allocations never share addresses.
      overlaps :: forall a. Storable a => UniqueArray a -> UniqueArray a -> Int -> Bool
      overlaps (UniqueArray _ (Lifetime _ _ (ForeignPtr addr1# _))) (UniqueArray _ (Lifetime _ _ (ForeignPtr addr2# _))) s =
        let i = I# (addr2Int# addr1#)
            j = I# (addr2Int# addr2#)
            k = s * sizeOf (undefined::a)
        in
        between i j (j + n*k) || between j i (i + m*k)

      -- @between x y z@: does @x@ lie in the half-open interval @[y, z)@?
      between :: Int -> Int -> Int -> Bool
      between x y z = x >= y && x < z

  -- Allocate storage for @n@ elements with 'newArrayData'; the IO action is
  -- lifted into the caller's monad with 'unsafePrimToPrim'.
  basicUnsafeNew n = unsafePrimToPrim $ MArray ((),n) <$> newArrayData (eltR @e) n

  -- Fill every component array with zero bytes.
  basicInitialize (MArray ((),n) mad) = unsafePrimToPrim $ go (eltR @e) mad
    where
      go :: TypeR a -> MutableArrayData a -> IO ()
      go TupRunit           ()       = return ()
      go (TupRpair aR1 aR2) (a1, a2) = go aR1 a1 >> go aR2 a2
      go (TupRsingle aR)    a        = scalar aR a

      scalar :: ScalarType a -> MutableArrayData a -> IO ()
      scalar (SingleScalarType t) a = single t a 1
      scalar (VectorScalarType t) a = vector t a

      vector :: VectorType a -> MutableArrayData a -> IO ()
      vector (VectorType w t) a
        | SingleArrayDict <- singleArrayDict t
        = single t a w

      single :: SingleType a -> MutableArrayData a -> Int -> IO ()
      single (NumSingleType t) = num t

      num :: NumType a -> MutableArrayData a -> Int -> IO ()
      num (IntegralNumType t) = integral t
      num (FloatingNumType t) = floating t

      integral :: IntegralType a -> MutableArrayData a -> Int -> IO ()
      integral TypeInt    = initialise
      integral TypeInt8   = initialise
      integral TypeInt16  = initialise
      integral TypeInt32  = initialise
      integral TypeInt64  = initialise
      integral TypeWord   = initialise
      integral TypeWord8  = initialise
      integral TypeWord16 = initialise
      integral TypeWord32 = initialise
      integral TypeWord64 = initialise

      floating :: FloatingType a -> MutableArrayData a -> Int -> IO ()
      floating TypeHalf   = initialise
      floating TypeFloat  = initialise
      floating TypeDouble = initialise

      -- Zero @n@ elements, each made of @s@ primitive values of type @a@.
      initialise :: forall a. Storable a => UniqueArray a -> Int -> IO ()
      initialise ua s = withUniqueArrayPtr ua $ \p -> fillBytes p 0 (n * s * sizeOf (undefined::a))

  -- Element access goes through the representation type: 'toElt' / 'fromElt'
  -- convert between the surface type @e@ and @EltR e@. No bounds are checked.
  basicUnsafeRead  (MArray _ mad) i   = unsafePrimToPrim $ toElt <$> readArrayData (eltR @e) mad i
  basicUnsafeWrite (MArray _ mad) i v = unsafePrimToPrim $ writeArrayData (eltR @e) mad i (fromElt v)

  -- Copy the source (second argument) into the destination (first argument),
  -- component array by component array. The byte count is taken from the
  -- length of the source.
  basicUnsafeCopy (MArray _ dst) (MArray ((),n) src) = unsafePrimToPrim $ go (eltR @e) dst src
    where
      go :: TypeR a -> MutableArrayData a -> MutableArrayData a -> IO ()
      go TupRunit           ()       ()       = return ()
      go (TupRpair aR1 aR2) (a1, a2) (b1, b2) = go aR1 a1 b1 >> go aR2 a2 b2
      go (TupRsingle aR)    a        b        = scalar aR a b

      scalar :: ScalarType a -> MutableArrayData a -> MutableArrayData a -> IO ()
      scalar (SingleScalarType t) a b = single t a b 1
      scalar (VectorScalarType t) a b = vector t a b

      vector :: VectorType a -> MutableArrayData a -> MutableArrayData a -> IO ()
      vector (VectorType w t) a b
        | SingleArrayDict <- singleArrayDict t
        = single t a b w

      single :: SingleType a -> MutableArrayData a -> MutableArrayData a -> Int -> IO ()
      single (NumSingleType t) = num t

      num :: NumType a -> MutableArrayData a -> MutableArrayData a -> Int -> IO ()
      num (IntegralNumType t) = integral t
      num (FloatingNumType t) = floating t

      integral :: IntegralType a -> MutableArrayData a -> MutableArrayData a -> Int -> IO ()
      integral TypeInt    = copy
      integral TypeInt8   = copy
      integral TypeInt16  = copy
      integral TypeInt32  = copy
      integral TypeInt64  = copy
      integral TypeWord   = copy
      integral TypeWord8  = copy
      integral TypeWord16 = copy
      integral TypeWord32 = copy
      integral TypeWord64 = copy

      floating :: FloatingType a -> MutableArrayData a -> MutableArrayData a -> Int -> IO ()
      floating TypeHalf   = copy
      floating TypeFloat  = copy
      floating TypeDouble = copy

      -- 'copyBytes' takes the destination first: @uu@ is written, @uv@ is
      -- read. Both arrays are held open for the duration of the copy.
      copy :: forall a. Storable a => UniqueArray a -> UniqueArray a -> Int -> IO ()
      copy uu uv s =
        withUniqueArrayPtr uu $ \u ->
        withUniqueArrayPtr uv $ \v ->
          copyBytes u v (n * s * sizeOf (undefined::a))
#if !MIN_VERSION_base(4,10,0)
-- | Local stand-in for @plusForeignPtr@, compiled only when the CPP guard
-- reports a @base@ older than 4.10.
--
-- Moves the address held by the foreign pointer forward by the given byte
-- offset and carries the contents field over unchanged.
plusForeignPtr :: ForeignPtr a -> Int -> ForeignPtr b
plusForeignPtr fptr bytes =
  case fptr of
    ForeignPtr base# contents ->
      case bytes of
        I# off# -> ForeignPtr (plusAddr# base# off#) contents
#endif