{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Massiv.Array.Manifest.Primitive
( P(..)
, Array(..)
, Prim
, vectorToByteArray
, toByteArray
, fromByteArray
, toMutableByteArray
, fromMutableByteArray
) where
import Control.DeepSeq (NFData (..), deepseq)
import Control.Monad.ST (runST)
import Data.Massiv.Array.Delayed.Internal (eq, ord)
import Data.Massiv.Array.Manifest.Internal
import Data.Massiv.Array.Manifest.List as A
import Data.Massiv.Array.Mutable
import Data.Massiv.Array.Unsafe (unsafeGenerateArray,
unsafeGenerateArrayP)
import Data.Massiv.Core.Common
import Data.Massiv.Core.List
import Data.Primitive (sizeOf)
import Data.Primitive.ByteArray
import Data.Primitive.Types
import qualified Data.Vector.Primitive as VP
import GHC.Base (Int (..))
import GHC.Exts as GHC (IsList (..))
import GHC.Prim
import Prelude hiding (mapM)
#include "massiv.h"
-- | Representation tag for arrays whose elements are stored in a
-- 'ByteArray' (see the @Array P@ data instance).
data P = P
  deriving (Show)

-- | The element representation associated with 'P' is 'M'.
type instance EltRepr P ix = M
-- | Immutable array in the 'P' representation.
data instance Array P ix e = PArray
  { pComp :: !Comp
    -- ^ Computation strategy carried by the array.
  , pSize :: !ix
    -- ^ Size of the array, expressed as an index.
  , pData :: {-# UNPACK #-} !ByteArray
    -- ^ Underlying buffer holding the elements.
  }
-- | Fully evaluates the computation strategy and the size; the byte array
-- itself only needs to be brought to weak head normal form.
instance Index ix => NFData (Array P ix e) where
  rnf (PArray comp sz bytes) = rnf comp `seq` rnf sz `seq` bytes `seq` ()
  {-# INLINE rnf #-}
-- | Equality is delegated to the shared 'eq' helper, comparing elements
-- with their own '=='.
instance (Prim e, Eq e, Index ix) => Eq (Array P ix e) where
  lhs == rhs = eq (==) lhs rhs
  {-# INLINE (==) #-}
-- | Ordering is delegated to the shared 'ord' helper, comparing elements
-- with their own 'compare'.
instance (Prim e, Ord e, Index ix) => Ord (Array P ix e) where
  compare lhs rhs = ord compare lhs rhs
  {-# INLINE compare #-}
-- | Construction of 'P' arrays: the computation strategy selects between
-- the sequential and the parallel generator.
instance (Prim e, Index ix) => Construct P ix e where
  getComp (PArray comp _ _) = comp
  {-# INLINE getComp #-}

  setComp comp arr = arr {pComp = comp}
  {-# INLINE setComp #-}

  unsafeMakeArray comp !sz f =
    case comp of
      Seq -> unsafeGenerateArray sz f
      ParOn wIds -> unsafeGenerateArrayP wIds sz f
  {-# INLINE unsafeMakeArray #-}
-- | Number of whole elements that fit in an immutable byte array. The
-- first argument is used only to select the element size.
elemsByteArray :: Prim a => a -> ByteArray -> Int
elemsByteArray witness bytes = div (sizeofByteArray bytes) (sizeOf witness)
{-# INLINE elemsByteArray #-}
-- | Linear indexing reads elements straight out of the underlying buffer.
instance (Prim e, Index ix) => Source P ix e where
  -- The reader is wrapped by the index-check macro defined in massiv.h; its
  -- expansion is not visible in this module.
  unsafeLinearIndex arr =
    INDEX_CHECK("(Source P ix e).unsafeLinearIndex",
                elemsByteArray (undefined :: e), indexByteArray) (pData arr)
  {-# INLINE unsafeLinearIndex #-}
-- | Size queries and resizing for 'P' arrays. Extraction goes through the
-- manifest representation obtained with 'toManifest'.
instance (Prim e, Index ix) => Size P ix e where
  size (PArray _ sz _) = sz
  {-# INLINE size #-}

  -- Only the recorded size changes; the buffer is reused as is.
  unsafeResize !newSz (PArray comp _ bytes) = PArray comp newSz bytes
  {-# INLINE unsafeResize #-}

  unsafeExtract !start !extent !arr =
    unsafeExtract start extent $ toManifest arr
  {-# INLINE unsafeExtract #-}
-- | Slicing a one-dimensional array yields a single element; the last two
-- arguments are ignored.
instance {-# OVERLAPPING #-} Prim e => Slice P Ix1 e where
  unsafeSlice arr ix _ _ = Just $ unsafeLinearIndex arr ix
  {-# INLINE unsafeSlice #-}
-- | Slicing in higher dimensions is delegated to the 'M' representation.
instance ( Prim e
         , Index ix
         , Index (Lower ix)
         , Elt P ix e ~ Elt M ix e
         , Elt M ix e ~ Array M (Lower ix) e
         ) =>
         Slice P ix e where
  unsafeSlice = unsafeSlice . toManifest
  {-# INLINE unsafeSlice #-}
-- | An outer slice of a one-dimensional array is just a linear lookup.
instance {-# OVERLAPPING #-} Prim e => OuterSlice P Ix1 e where
  unsafeOuterSlice arr i = unsafeLinearIndex arr i
  {-# INLINE unsafeOuterSlice #-}
-- | Outer slicing in higher dimensions is delegated to the 'M'
-- representation.
instance ( Prim e
         , Index ix
         , Index (Lower ix)
         , Elt M ix e ~ Array M (Lower ix) e
         , Elt P ix e ~ Array M (Lower ix) e
         ) =>
         OuterSlice P ix e where
  unsafeOuterSlice = unsafeOuterSlice . toManifest
  {-# INLINE unsafeOuterSlice #-}
-- | An inner slice of a one-dimensional array is just a linear lookup; the
-- second argument is ignored.
instance {-# OVERLAPPING #-} Prim e => InnerSlice P Ix1 e where
  unsafeInnerSlice arr _ i = unsafeLinearIndex arr i
  {-# INLINE unsafeInnerSlice #-}
-- | Inner slicing in higher dimensions is delegated to the 'M'
-- representation.
instance ( Prim e
         , Index ix
         , Index (Lower ix)
         , Elt M ix e ~ Array M (Lower ix) e
         , Elt P ix e ~ Array M (Lower ix) e
         ) =>
         InnerSlice P ix e where
  unsafeInnerSlice = unsafeInnerSlice . toManifest
  {-# INLINE unsafeInnerSlice #-}
-- | Manifest linear indexing reads elements straight out of the underlying
-- buffer.
instance (Index ix, Prim e) => Manifest P ix e where
  -- The reader is wrapped by the index-check macro defined in massiv.h; its
  -- expansion is not visible in this module.
  unsafeLinearIndexM arr =
    INDEX_CHECK("(Manifest P ix e).unsafeLinearIndexM",
                elemsByteArray (undefined :: e), indexByteArray) (pData arr)
  {-# INLINE unsafeLinearIndexM #-}
-- | Number of whole elements that fit in a mutable byte array. The first
-- argument is used only to select the element size.
elemsMutableByteArray :: Prim a => a -> MutableByteArray s -> Int
elemsMutableByteArray witness mbytes =
  div (sizeofMutableByteArray mbytes) (sizeOf witness)
{-# INLINE elemsMutableByteArray #-}
-- | Mutable counterpart of the 'P' representation, backed by a
-- 'MutableByteArray'.
instance (Index ix, Prim e) => Mutable P ix e where
  -- A mutable array is its size paired with the mutable buffer.
  data MArray s P ix e = MPArray !ix !(MutableByteArray s)

  msize (MPArray dims _) = dims
  {-# INLINE msize #-}

  -- The computation strategy of the source array is dropped here.
  unsafeThaw (PArray _ dims bytes) =
    fmap (MPArray dims) (unsafeThawByteArray bytes)
  {-# INLINE unsafeThaw #-}

  unsafeFreeze comp (MPArray dims mbytes) =
    fmap (PArray comp dims) (unsafeFreezeByteArray mbytes)
  {-# INLINE unsafeFreeze #-}

  unsafeNew dims = fmap (MPArray dims) (newByteArray nBytes)
    where
      -- Buffer size in bytes: element count times element size.
      nBytes = I# (totalSize# dims (undefined :: e))
  {-# INLINE unsafeNew #-}

  -- Same allocation as unsafeNew, followed by filling every byte with zero.
  unsafeNewZero dims = do
    let !nBytes = I# (totalSize# dims (undefined :: e))
    mbytes <- newByteArray nBytes
    fillByteArray mbytes 0 nBytes 0
    pure (MPArray dims mbytes)
  {-# INLINE unsafeNewZero #-}

  -- Reads and writes are wrapped by the index-check macro defined in
  -- massiv.h; its expansion is not visible in this module.
  unsafeLinearRead (MPArray _ mbytes) =
    INDEX_CHECK("(Mutable P ix e).unsafeLinearRead",
                elemsMutableByteArray (undefined :: e), readByteArray) mbytes
  {-# INLINE unsafeLinearRead #-}

  unsafeLinearWrite (MPArray _ mbytes) =
    INDEX_CHECK("(Mutable P ix e).unsafeLinearWrite",
                elemsMutableByteArray (undefined :: e), writeByteArray) mbytes
  {-# INLINE unsafeLinearWrite #-}

  -- The variants below thread the state token explicitly and work directly
  -- on the unlifted primitives.
  unsafeNewA dims (State s0#) =
    case newByteArray# (totalSize# dims (undefined :: e)) s0# of
      (# s1#, mba# #) -> pure (State s1#, MPArray dims (MutableByteArray mba#))
  {-# INLINE unsafeNewA #-}

  -- The immutable buffer is coerced to a mutable one; the state is returned
  -- unchanged.
  unsafeThawA (PArray _ dims (ByteArray ba#)) st =
    pure (st, MPArray dims (MutableByteArray (unsafeCoerce# ba#)))
  {-# INLINE unsafeThawA #-}

  unsafeFreezeA comp (MPArray dims (MutableByteArray mba#)) (State s0#) =
    case unsafeFreezeByteArray# mba# s0# of
      (# s1#, ba# #) -> pure (State s1#, PArray comp dims (ByteArray ba#))
  {-# INLINE unsafeFreezeA #-}

  unsafeLinearWriteA (MPArray _ (MutableByteArray mba#)) (I# i#) x (State s0#) =
    pure (State (writeByteArray# mba# i# x s0#))
  {-# INLINE unsafeLinearWriteA #-}
-- | Size in bytes, as an unboxed integer, of a buffer able to hold every
-- element of an array of the given size. The second argument is used only
-- to select the element size.
totalSize# :: (Index ix, Prim e) => ix -> e -> Int#
totalSize# sz witness =
  case totalElem sz of
    I# count# -> count# *# sizeOf# witness
{-# INLINE totalSize# #-}
-- | List literals for 'P' arrays. Construction is sequential; conversion
-- back to lists goes through the list array produced by 'toListArray'.
instance ( VP.Prim e
         , IsList (Array L ix e)
         , Nested LN ix e
         , Nested L ix e
         , Ragged L ix e
         ) =>
         IsList (Array P ix e) where
  type Item (Array P ix e) = Item (Array L ix e)

  fromList xs = A.fromLists' Seq xs
  {-# INLINE fromList #-}

  toList arr = GHC.toList (toListArray arr)
  {-# INLINE toList #-}
-- | Convert a primitive vector into a 'ByteArray' that contains exactly the
-- elements of the vector, starting at byte offset zero.
--
-- The backing array of the vector is shared, without copying, only when the
-- vector starts at offset zero and spans the whole backing array. Otherwise
-- (a slice with a non-zero offset, or a prefix of a larger buffer) the
-- relevant bytes are copied into a freshly allocated array, so the result
-- never carries elements that are not part of the vector.
vectorToByteArray :: forall e . VP.Prim e => VP.Vector e -> ByteArray
vectorToByteArray (VP.Vector start len arr)
  | start == 0 && sizeofByteArray arr == byteCount = arr
  | otherwise =
    runST $ do
      -- The vector offset and length are counted in elements, whereas the
      -- byte array primitives take sizes and offsets in bytes, so the new
      -- buffer must be allocated with the byte count, not the element count.
      marr <- newByteArray byteCount
      copyByteArray marr 0 arr (start * elSize) byteCount
      unsafeFreezeByteArray marr
  where
    elSize = sizeOf (undefined :: e)
    byteCount = len * elSize
{-# INLINE vectorToByteArray #-}
-- | Witness for the element type of a 'P' array. The result is 'undefined'
-- and exists only to carry the type @e@; in this module it is passed to the
-- element-counting helpers, which hand it on to 'sizeOf'.
primArrayDummy :: arr P ix e -> e
primArrayDummy = undefined
{-# INLINE primArrayDummy #-}
-- | Extract the underlying 'ByteArray' of a 'P' array.
toByteArray :: Array P ix e -> ByteArray
toByteArray (PArray _ _ bytes) = bytes
{-# INLINE toByteArray #-}
-- | Wrap a 'ByteArray' as a 'P' array with the given computation strategy
-- and size. Yields 'Nothing' unless the total number of elements implied by
-- the size equals the number of whole elements that fit in the byte array.
fromByteArray :: (Index ix, Prim e) => Comp -> ix -> ByteArray -> Maybe (Array P ix e)
fromByteArray comp sz ba
  | totalElem sz == elemsByteArray (primArrayDummy wrapped) ba = Just wrapped
  | otherwise = Nothing
  where
    -- Built lazily; also serves as the element-type witness for the check.
    wrapped = PArray comp sz ba
{-# INLINE fromByteArray #-}
-- | Extract the underlying 'MutableByteArray' of a mutable 'P' array.
toMutableByteArray :: MArray s P ix e -> MutableByteArray s
toMutableByteArray marr =
  case marr of
    MPArray _ mbytes -> mbytes
{-# INLINE toMutableByteArray #-}
-- | Wrap a 'MutableByteArray' as a mutable 'P' array of the given size.
-- Yields 'Nothing' unless the total number of elements implied by the size
-- equals the number of whole elements that fit in the byte array.
fromMutableByteArray :: (Index ix, Prim e) => ix -> MutableByteArray s -> Maybe (MArray s P ix e)
fromMutableByteArray sz ba
  | totalElem sz == elemsMutableByteArray (primArrayDummy wrapped) ba = Just wrapped
  | otherwise = Nothing
  where
    -- Built lazily; also serves as the element-type witness for the check.
    wrapped = MPArray sz ba
{-# INLINE fromMutableByteArray #-}