module Data.Repa.Array.Material.Strided
( S (..)
, Name (..)
, Array (..)
, unsafeCast
, fromForeignPtr, toForeignPtr)
where
import Data.Repa.Array.Meta.Window
import Data.Repa.Array.Internals.Bulk
import Data.Repa.Array.Internals.Layout
import Data.Repa.Fusion.Unpack
import Data.Word
import qualified Foreign.Storable as S
import qualified Foreign.ForeignPtr as F
import qualified Data.ByteString.Internal as BS
#include "repa-array.h"
-- | Layout tag for strided arrays.
--
--   The layout records only how many elements the array holds; the byte
--   offset of the first element and the stride between elements are
--   stored in the array value itself.
data S
        = Strided
        { stridedLength :: !Int -- ^ Number of elements.
        }
        deriving (Show, Eq)
-- | Strided layouts are one-dimensional: an index is a plain 'Int'
--   element position, so converting between indices and linear
--   positions is the identity in both directions.
instance Layout S where
 data Name S            = S
 type Index S           = Int

 name                   = S

 -- The name argument is matched so that it is forced, as before.
 create S               = Strided

 extent                 = stridedLength

 toIndex   _            = id
 fromIndex _            = id

deriving instance Show (Name S)
deriving instance Eq   (Name S)
-- | Strided arrays of 'S.Storable' elements held in a foreign buffer.
--
--   Element @i@ is read from byte offset
--   @sArrayStartBytes + i * sArrayStrideBytes@ of the buffer.
instance S.Storable a => Bulk S a where
 data Array S a
        = SArray
        { sArrayStartBytes      :: !Int                 -- byte offset of element 0
        , sArrayStrideBytes     :: !Int                 -- bytes between consecutive elements
        , sArrayLenElems        :: !Int                 -- number of elements
        , sArrayPtr             :: !(F.ForeignPtr a) }  -- buffer holding the elements

 -- 'Strided' has a strict field, so this still forces the array.
 layout = Strided . sArrayLenElems

 -- No bounds check is made here: the caller must supply a valid index.
 index (SArray offset0 step n buf) i
  = BS.inlinePerformIO (F.withForeignPtr buf readAt)
  where
        -- Position of element i, in bytes from the start of the buffer.
        byteOffset  = offset0 + toIndex (Strided n) i * step
        readAt p    = S.peekByteOff p byteOffset

deriving instance (S.Storable a, Show a) => Show (Array S a)
-- | Convert a strided array to and from a flat tuple of its fields,
--   in declaration order: start offset (bytes), stride (bytes),
--   length (elements), buffer.
instance Unpack (Array S a) (Int, Int, Int, F.ForeignPtr a) where
 unpack arr
  = case arr of
        SArray offset0 step n buf -> (offset0, step, n, buf)

 -- The first argument is not consulted; the array is rebuilt purely
 -- from the tuple.
 repack _ (offset0, step, n, buf)
  = SArray offset0 step n buf
-- | Take a window of @lenElems'@ elements starting at element
--   @startElems'@, sharing the same buffer and stride.
--
--   No bounds check is made against the original length.
instance S.Storable a => Windowable S a where
 window startElems' lenElems'
        (SArray startBytes strideBytes _lenElems fptr)
  -- Element i of a strided array lives at byte
  -- @startBytes + i * strideBytes@, so the new element 0 must sit where
  -- the old element @startElems'@ did. Advancing by the element size
  -- instead would only be correct for densely packed arrays.
  = SArray (startBytes + strideBytes * startElems')
           strideBytes lenElems' fptr
-- | Reinterpret the elements of a strided array as another 'S.Storable'
--   type.
--
--   The start offset, stride and element count are kept as-is and only
--   the pointer type changes. Nothing checks that the new element type
--   fits within the stride, hence "unsafe".
unsafeCast
        :: (S.Storable a, S.Storable b)
        => Array S a -> Array S b
unsafeCast arr
 = case arr of
        SArray offset0 step n buf
         -> SArray offset0 step n (F.castForeignPtr buf)
-- | Wrap a foreign buffer as a strided array, without copying.
fromForeignPtr
        :: Int                  -- ^ Byte offset of the first element.
        -> Int                  -- ^ Stride in bytes between consecutive elements.
        -> Int                  -- ^ Number of elements.
        -> F.ForeignPtr a       -- ^ Buffer holding the elements.
        -> Array S a
fromForeignPtr = SArray
-- | Expose the buffer behind a strided array as
--   @(startBytes, strideBytes, lenElems, fptr)@, in the same order that
--   'fromForeignPtr' takes them.
toForeignPtr
        :: Array S a
        -> (Int, Int, Int, F.ForeignPtr a)
toForeignPtr arr
 = case arr of
        SArray offset0 step n buf -> (offset0, step, n, buf)