{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Massiv.Array.Manifest.Primitive
( P(..)
, Array(..)
, Prim
, toPrimitiveVector
, toPrimitiveMVector
, fromPrimitiveVector
, fromPrimitiveMVector
, toByteArray
, toByteArrayM
, unwrapByteArray
, unwrapMutableByteArray
, fromByteArray
, fromByteArrayM
, toMutableByteArray
, toMutableByteArrayM
, fromMutableByteArrayM
, fromMutableByteArray
, shrinkMutableByteArray
, unsafeAtomicReadIntArray
, unsafeAtomicWriteIntArray
, unsafeCasIntArray
, unsafeAtomicModifyIntArray
, unsafeAtomicAddIntArray
, unsafeAtomicSubIntArray
, unsafeAtomicAndIntArray
, unsafeAtomicNandIntArray
, unsafeAtomicOrIntArray
, unsafeAtomicXorIntArray
) where
import Control.DeepSeq (NFData(..), deepseq)
import Control.Monad.Primitive (PrimMonad(primitive), PrimState, primitive_)
import Data.Massiv.Array.Delayed.Pull (eq, ord)
import Data.Massiv.Array.Manifest.Internal
import Data.Massiv.Array.Manifest.List as A
import Data.Massiv.Array.Mutable
import Data.Massiv.Core.Common
import Data.Massiv.Core.List
import Data.Massiv.Vector.Stream as S (steps, isteps)
import Data.Maybe (fromMaybe)
import Data.Primitive (sizeOf)
import Data.Primitive.ByteArray
import Data.Primitive.Types
import qualified Data.Vector.Primitive as VP
import qualified Data.Vector.Primitive.Mutable as MVP
import GHC.Base (Int(..))
import GHC.Exts as GHC
import Prelude hiding (mapM)
import System.IO.Unsafe (unsafePerformIO)
#include "massiv.h"
-- | Representation for arrays whose elements are stored unboxed in a single
-- @ByteArray@, using the @Prim@ class for reading and writing.
data P =
  P
  deriving Show

-- | Immutable primitive array.
data instance Array P ix e = PArray
  { pComp   :: !Comp
    -- ^ Computation strategy
  , pSize   :: !(Sz ix)
    -- ^ Logical size of the array
  , pOffset :: {-# UNPACK #-} !Int
    -- ^ Offset, counted in elements, of the first element inside @pData@
  , pData   :: {-# UNPACK #-} !ByteArray
    -- ^ Underlying buffer, which can be larger than the array it backs
  }
-- | Rendering is delegated to the generic array printers @showsArrayPrec@
-- and @showArrayList@.
instance (Ragged L ix e, Show e, Prim e) => Show (Array P ix e) where
  showsPrec prec arr = showsArrayPrec id prec arr
  showList arrs = showArrayList arrs
-- | Fully evaluates the computation strategy and size. The offset and the
-- byte array are only forced to WHNF, because the buffer holds unboxed data.
instance Index ix => NFData (Array P ix e) where
  rnf PArray {pComp, pSize, pOffset, pData} =
    rnf pComp `seq` rnf pSize `seq` pOffset `seq` pData `seq` ()
  {-# INLINE rnf #-}
-- | Element-wise equality, delegated to the generic @eq@ with @(==)@.
instance (Prim e, Eq e, Index ix) => Eq (Array P ix e) where
  lhs == rhs = eq (==) lhs rhs
  {-# INLINE (==) #-}
-- | Ordering delegated to the generic @ord@ with the element @compare@.
instance (Prim e, Ord e, Index ix) => Ord (Array P ix e) where
  compare lhs rhs = ord compare lhs rhs
  {-# INLINE compare #-}
-- | Construction of primitive arrays.
instance (Prim e, Index ix) => Construct P ix e where
  -- Only the computation strategy changes; the buffer is shared.
  setComp newComp arr = arr {pComp = newComp}
  {-# INLINE setComp #-}
  -- NOTE(review): unsafePerformIO runs generateArray, which builds the array
  -- from the pure generator @f@; presumably safe because the result depends
  -- only on @comp@, @sz@ and @f@ and the array is freshly built -- confirm.
  makeArray !comp !sz f = unsafePerformIO (generateArray comp sz (pure . f))
  {-# INLINE makeArray #-}
-- | Read access to primitive arrays. Linear indices are relative to the
-- array and are shifted by the stored element offset before reading.
instance (Prim e, Index ix) => Source P ix e where
  unsafeLinearIndex _src@(PArray _ _ off ba) idx =
    INDEX_CHECK("(Source P ix e).unsafeLinearIndex",
                SafeSz . elemsBA _src, indexByteArray) ba (idx + off)
  {-# INLINE unsafeLinearIndex #-}
  -- Slicing shares the buffer and only moves the offset and size.
  unsafeLinearSlice start newSz PArray {pComp, pOffset, pData} =
    PArray pComp newSz (pOffset + start) pData
  {-# INLINE unsafeLinearSlice #-}
-- | Resizing only replaces the size; offset and buffer are kept as is.
instance Index ix => Resize P ix where
  unsafeResize !newSz (PArray comp _ off ba) = PArray comp newSz off ba
  {-# INLINE unsafeResize #-}
-- | Extraction goes through the @M@ representation of the array.
instance (Prim e, Index ix) => Extract P ix e where
  unsafeExtract !start !newSz !arr = unsafeExtract start newSz $ toManifest arr
  {-# INLINE unsafeExtract #-}
-- | A slice of a flat vector is a single element.
instance {-# OVERLAPPING #-} Prim e => Slice P Ix1 e where
  unsafeSlice arr start _ _ = pure $ unsafeLinearIndex arr start
  {-# INLINE unsafeSlice #-}
-- | Multi-dimensional slicing is delegated to the @M@ representation.
instance ( Prim e
         , Index ix
         , Index (Lower ix)
         , Elt P ix e ~ Elt M ix e
         , Elt M ix e ~ Array M (Lower ix) e
         ) =>
         Slice P ix e where
  unsafeSlice = unsafeSlice . toManifest
  {-# INLINE unsafeSlice #-}
-- | The outer slice of a flat vector is the element at that index.
instance {-# OVERLAPPING #-} Prim e => OuterSlice P Ix1 e where
  unsafeOuterSlice arr i = unsafeLinearIndex arr i
  {-# INLINE unsafeOuterSlice #-}
-- | Multi-dimensional outer slicing is delegated to the @M@ representation.
instance ( Prim e
         , Index ix
         , Index (Lower ix)
         , Elt M ix e ~ Array M (Lower ix) e
         , Elt P ix e ~ Array M (Lower ix) e
         ) =>
         OuterSlice P ix e where
  unsafeOuterSlice = unsafeOuterSlice . toManifest
  {-# INLINE unsafeOuterSlice #-}
-- | The inner slice of a flat vector is the element at that index; the
-- size argument is irrelevant here and ignored.
instance {-# OVERLAPPING #-} Prim e => InnerSlice P Ix1 e where
  unsafeInnerSlice arr _ i = unsafeLinearIndex arr i
  {-# INLINE unsafeInnerSlice #-}
-- | Multi-dimensional inner slicing is delegated to the @M@ representation.
instance ( Prim e
         , Index ix
         , Index (Lower ix)
         , Elt M ix e ~ Array M (Lower ix) e
         , Elt P ix e ~ Array M (Lower ix) e
         ) =>
         InnerSlice P ix e where
  unsafeInnerSlice = unsafeInnerSlice . toManifest
  {-# INLINE unsafeInnerSlice #-}
-- | Monadic linear indexing into a primitive array.
--
-- The index handed to the index-check wrapper already includes the element
-- offset (@i + o@). It is therefore validated against the capacity of the
-- whole buffer, exactly like @unsafeLinearIndex@ of the @Source@ instance.
-- Validating it against the logical size would reject valid reads from any
-- array with a non-zero offset, for example one produced by slicing.
instance (Index ix, Prim e) => Manifest P ix e where
  unsafeLinearIndexM _pa@(PArray _ _ o a) i =
    INDEX_CHECK("(Manifest P ix e).unsafeLinearIndexM",
                SafeSz . elemsBA _pa, indexByteArray) a (i + o)
  {-# INLINE unsafeLinearIndexM #-}
-- | Mutable primitive arrays.
--
-- The @Int@ field of @MPArray@ is an element offset into the underlying
-- @MutableByteArray@. Wherever the byte-array API expects a byte offset or a
-- byte count, the offset is multiplied by @sizeOf@ of the element.
instance (Index ix, Prim e) => Mutable P ix e where
  data MArray s P ix e = MPArray !(Sz ix) {-# UNPACK #-} !Int {-# UNPACK #-} !(MutableByteArray s)
  msize (MPArray sz _ _) = sz
  {-# INLINE msize #-}
  unsafeThaw (PArray _ sz o a) = MPArray sz o <$> unsafeThawByteArray a
  {-# INLINE unsafeThaw #-}
  unsafeFreeze comp (MPArray sz o a) = PArray comp sz o <$> unsafeFreezeByteArray a
  {-# INLINE unsafeFreeze #-}
  -- Refuse sizes whose byte count would overflow an Int before allocating.
  unsafeNew sz
    | n <= (maxBound :: Int) `div` eSize = MPArray sz 0 <$> newByteArray (n * eSize)
    | otherwise = error $ "Array size is too big: " ++ show sz
    where !n = totalElem sz
          !eSize = sizeOf (undefined :: e)
  {-# INLINE unsafeNew #-}
  -- Zero the bytes that belong to this array. fillByteArray takes both its
  -- offset and its length in bytes, so the element offset has to be scaled
  -- by the element size as well, not only the length.
  initialize (MPArray sz o mba) =
    fillByteArray mba (o * eSize) (totalElem sz * eSize) 0
    where
      eSize = sizeOf (undefined :: e)
  {-# INLINE initialize #-}
  -- The checked index includes the offset, so it is validated against the
  -- capacity of the whole buffer rather than against the logical size.
  unsafeLinearRead _mpa@(MPArray _ o ma) i =
    INDEX_CHECK("(Mutable P ix e).unsafeLinearRead",
                SafeSz . elemsMBA _mpa, readByteArray) ma (i + o)
  {-# INLINE unsafeLinearRead #-}
  unsafeLinearWrite _mpa@(MPArray _ o ma) i =
    INDEX_CHECK("(Mutable P ix e).unsafeLinearWrite",
                SafeSz . elemsMBA _mpa, writeByteArray) ma (i + o)
  {-# INLINE unsafeLinearWrite #-}
  unsafeLinearSet (MPArray _ o ma) offset (SafeSz sz) = setByteArray ma (offset + o) sz
  {-# INLINE unsafeLinearSet #-}
  -- copyMutableByteArray works in bytes, hence the scaling by esz.
  unsafeLinearCopy (MPArray _ oFrom maFrom) iFrom (MPArray _ oTo maTo) iTo (Sz k) =
    copyMutableByteArray maTo ((oTo + iTo) * esz) maFrom ((oFrom + iFrom) * esz) (k * esz)
    where esz = sizeOf (undefined :: e)
  {-# INLINE unsafeLinearCopy #-}
  unsafeArrayLinearCopy (PArray _ _ oFrom aFrom) iFrom (MPArray _ oTo maTo) iTo (Sz k) =
    copyByteArray maTo ((oTo + iTo) * esz) aFrom ((oFrom + iFrom) * esz) (k * esz)
    where esz = sizeOf (undefined :: e)
  {-# INLINE unsafeArrayLinearCopy #-}
  -- The buffer keeps the elements before the offset, so the new byte size
  -- covers offset plus the new element count.
  unsafeLinearShrink (MPArray _ o ma) sz = do
    shrinkMutableByteArray ma ((o + totalElem sz) * sizeOf (undefined :: e))
    pure $ MPArray sz o ma
  {-# INLINE unsafeLinearShrink #-}
  unsafeLinearGrow (MPArray _ o ma) sz =
    MPArray sz o <$> resizeMutableByteArrayCompat ma ((o + totalElem sz) * sizeOf (undefined :: e))
  {-# INLINE unsafeLinearGrow #-}
-- | Loading a primitive array splits its linear index range across the
-- scheduler and reads each element with @unsafeLinearIndex@.
instance (Prim e, Index ix) => Load P ix e where
  type R P = M
  size arr = pSize arr
  {-# INLINE size #-}
  getComp arr = pComp arr
  {-# INLINE getComp #-}
  loadArrayM !sched !arr =
    splitLinearlyWith_ sched totalCount (unsafeLinearIndex arr)
    where
      totalCount = elemsCount arr
  {-# INLINE loadArrayM #-}
-- | Strided loading uses only the default methods of @StrideLoad@.
instance (Index ix, Prim e) => StrideLoad P ix e
-- | Streaming is provided by the generic @steps@ and @isteps@ producers.
instance (Prim e, Index ix) => Stream P ix e where
  toStream arr = S.steps arr
  {-# INLINE toStream #-}
  toStreamIx arr = S.isteps arr
  {-# INLINE toStreamIx #-}
-- | List conversion through the nested list representation. @fromList@
-- builds the array with the @Seq@ computation strategy.
instance ( Prim e
         , IsList (Array L ix e)
         , Nested LN ix e
         , Nested L ix e
         , Ragged L ix e
         ) =>
         IsList (Array P ix e) where
  type Item (Array P ix e) = Item (Array L ix e)
  fromList xs = A.fromLists' Seq xs
  {-# INLINE fromList #-}
  toList arr = GHC.toList (toListArray arr)
  {-# INLINE toList #-}
-- | Number of whole elements of type @e@ that fit into the byte array: the
-- byte size divided by the element size, rounded down.
elemsBA :: forall proxy e . Prim e => proxy e -> ByteArray -> Int
elemsBA _ ba = div (sizeofByteArray ba) eltBytes
  where
    eltBytes = sizeOf (undefined :: e)
{-# INLINE elemsBA #-}
-- | Number of whole elements of type @e@ that fit into the mutable byte
-- array: the byte size divided by the element size, rounded down.
elemsMBA :: forall proxy e s . Prim e => proxy e -> MutableByteArray s -> Int
elemsMBA _ mba = div (sizeofMutableByteArray mba) eltBytes
  where
    eltBytes = sizeOf (undefined :: e)
{-# INLINE elemsMBA #-}
-- | Get the underlying @ByteArray@. No copy is made when the array
-- occupies the whole buffer; otherwise a fresh array is computed and its
-- buffer is returned.
toByteArray :: (Index ix, Prim e) => Array P ix e -> ByteArray
toByteArray arr =
  case toByteArrayM arr of
    Just ba -> ba
    Nothing -> unwrapByteArray (compute arr)
{-# INLINE toByteArray #-}
-- | Raw access to the underlying buffer. Offset and size are ignored, so
-- the result can contain more data than the array itself.
unwrapByteArray :: Array P ix e -> ByteArray
unwrapByteArray PArray {pData} = pData
{-# INLINE unwrapByteArray #-}
-- | Return the underlying buffer, but only when its element capacity equals
-- the number of elements in the array; otherwise the check from
-- @guardNumberOfElements@ throws in @m@.
toByteArrayM :: (Prim e, Index ix, MonadThrow m) => Array P ix e -> m ByteArray
toByteArrayM arr@(PArray _ sz _ ba) =
  ba <$ guardNumberOfElements sz (Sz (elemsBA arr ba))
{-# INLINE toByteArrayM #-}
-- | Wrap a @ByteArray@ as an array of the given size with offset zero. This
-- throws in @m@ when the number of elements in the buffer does not match
-- the requested size.
fromByteArrayM :: (MonadThrow m, Index ix, Prim e) => Comp -> Sz ix -> ByteArray -> m (Array P ix e)
fromByteArrayM comp sz ba = do
  guardNumberOfElements sz (Sz (elemsBA arr ba))
  pure arr
  where
    arr = PArray comp sz 0 ba
{-# INLINE fromByteArrayM #-}
-- | View a whole @ByteArray@ as a flat vector. The length is the number of
-- whole elements that fit into the buffer.
fromByteArray :: forall e . Prim e => Comp -> ByteArray -> Array P Ix1 e
fromByteArray comp ba =
  PArray {pComp = comp, pSize = SafeSz count, pOffset = 0, pData = ba}
  where
    count = elemsBA (Proxy :: Proxy e) ba
{-# INLINE fromByteArray #-}
-- | Raw access to the underlying mutable buffer. Offset and size are
-- ignored.
unwrapMutableByteArray :: MArray s P ix e -> MutableByteArray s
unwrapMutableByteArray marr =
  case marr of
    MPArray _ _ buf -> buf
{-# INLINE unwrapMutableByteArray #-}
-- | Get a @MutableByteArray@ holding exactly the elements of the array.
--
-- The flag is @True@ when the original buffer is returned as is, because
-- its capacity matches the array size. It is @False@ when the elements had
-- to be copied into a newly allocated pinned buffer.
toMutableByteArray ::
     forall ix e m. (Prim e, Index ix, PrimMonad m)
  => MArray (PrimState m) P ix e
  -> m (Bool, MutableByteArray (PrimState m))
toMutableByteArray marr@(MPArray sz offset src) =
  maybe copyOut (\mba -> pure (True, mba)) (toMutableByteArrayM marr)
  where
    eltBytes = sizeOf (undefined :: e)
    byteCount = totalElem sz * eltBytes
    copyOut = do
      dst <- newPinnedByteArray byteCount
      copyMutableByteArray dst 0 src (offset * eltBytes) byteCount
      pure (False, dst)
{-# INLINE toMutableByteArray #-}
-- | Return the underlying mutable buffer, but only when its element capacity
-- equals the number of elements in the array; otherwise this throws in @m@.
toMutableByteArrayM :: (Index ix, Prim e, MonadThrow m) => MArray s P ix e -> m (MutableByteArray s)
toMutableByteArrayM marr@(MPArray sz _ mba) = do
  guardNumberOfElements sz (Sz (elemsMBA marr mba))
  pure mba
{-# INLINE toMutableByteArrayM #-}
-- | Wrap a @MutableByteArray@ as a mutable array of the given size with
-- offset zero. This throws in @m@ when the element capacity of the buffer
-- does not match the size.
fromMutableByteArrayM ::
     (MonadThrow m, Index ix, Prim e) => Sz ix -> MutableByteArray s -> m (MArray s P ix e)
fromMutableByteArrayM sz mba = do
  guardNumberOfElements sz (Sz (elemsMBA marr mba))
  pure marr
  where
    marr = MPArray sz 0 mba
{-# INLINE fromMutableByteArrayM #-}
-- | View a whole @MutableByteArray@ as a flat mutable vector. The length is
-- the number of whole elements that fit into the buffer.
fromMutableByteArray :: forall e s . Prim e => MutableByteArray s -> MArray s P Ix1 e
fromMutableByteArray mba = MPArray (SafeSz count) 0 mba
  where
    count = elemsMBA (Proxy :: Proxy e) mba
{-# INLINE fromMutableByteArray #-}
-- | Zero-copy conversion to a primitive vector that shares the buffer,
-- offset and total element count.
toPrimitiveVector :: Index ix => Array P ix e -> VP.Vector e
toPrimitiveVector (PArray _ sz off ba) = VP.Vector off (totalElem sz) ba
{-# INLINE toPrimitiveVector #-}
-- | Zero-copy conversion to a mutable primitive vector that shares the
-- buffer, offset and total element count.
toPrimitiveMVector :: Index ix => MArray s P ix e -> MVP.MVector s e
toPrimitiveMVector marr =
  case marr of
    MPArray sz off buf -> MVP.MVector off (totalElem sz) buf
{-# INLINE toPrimitiveMVector #-}
-- | Zero-copy conversion from a primitive vector. The result uses the @Seq@
-- computation strategy.
fromPrimitiveVector :: VP.Vector e -> Array P Ix1 e
fromPrimitiveVector (VP.Vector off len ba) = PArray Seq (SafeSz len) off ba
{-# INLINE fromPrimitiveVector #-}
-- | Zero-copy conversion from a mutable primitive vector.
fromPrimitiveMVector :: MVP.MVector s e -> MArray s P Ix1 e
fromPrimitiveMVector mvec =
  case mvec of
    MVP.MVector off len buf -> MPArray (SafeSz len) off buf
{-# INLINE fromPrimitiveMVector #-}
-- | Atomically read the @Int@ at index @ix@, using @atomicReadIntArray#@.
-- The index is shifted by the array offset. Apart from the optional
-- index-check wrapper from massiv.h (presumably only active in checked
-- builds -- confirm), callers must guarantee the index is in bounds.
unsafeAtomicReadIntArray ::
     (Index ix, PrimMonad m) => MArray (PrimState m) P ix Int -> ix -> m Int
unsafeAtomicReadIntArray _marr@(MPArray sz off buf) ix =
  INDEX_CHECK("unsafeAtomicReadIntArray", SafeSz . elemsMBA _marr, readAt)
    buf
    (off + toLinearIndex sz ix)
  where
    readAt (MutableByteArray buf#) (I# k#) =
      primitive $ \s0# ->
        case atomicReadIntArray# buf# k# s0# of
          (# s1#, v# #) -> (# s1#, I# v# #)
    {-# INLINE readAt #-}
{-# INLINE unsafeAtomicReadIntArray #-}
-- | Atomically write an @Int@ at index @ix@, using @atomicWriteIntArray#@.
-- The index is shifted by the array offset; callers must guarantee it is in
-- bounds.
unsafeAtomicWriteIntArray ::
     (Index ix, PrimMonad m) => MArray (PrimState m) P ix Int -> ix -> Int -> m ()
unsafeAtomicWriteIntArray _marr@(MPArray sz off buf) ix (I# new#) =
  INDEX_CHECK("unsafeAtomicWriteIntArray", SafeSz . elemsMBA _marr, writeAt)
    buf
    (off + toLinearIndex sz ix)
  where
    writeAt (MutableByteArray buf#) (I# k#) =
      primitive_ (atomicWriteIntArray# buf# k# new#)
    {-# INLINE writeAt #-}
{-# INLINE unsafeAtomicWriteIntArray #-}
-- | Compare-and-swap on the @Int@ at index @ix@, using @casIntArray#@. The
-- first value argument is the expected current value and the second is the
-- replacement. The result is the value observed at that index, as returned
-- by the primop. Callers must guarantee the index is in bounds.
unsafeCasIntArray ::
     (Index ix, PrimMonad m) => MArray (PrimState m) P ix Int -> ix -> Int -> Int -> m Int
unsafeCasIntArray _marr@(MPArray sz off buf) ix (I# expected#) (I# desired#) =
  INDEX_CHECK("unsafeCasIntArray", SafeSz . elemsMBA _marr, casAt)
    buf
    (off + toLinearIndex sz ix)
  where
    casAt (MutableByteArray buf#) (I# k#) =
      primitive $ \s0# ->
        case casIntArray# buf# k# expected# desired# s0# of
          (# s1#, prev# #) -> (# s1#, I# prev# #)
    {-# INLINE casAt #-}
{-# INLINE unsafeCasIntArray #-}
-- | Atomically apply a pure function to the @Int@ at index @ix@ and return
-- the value it replaced.
--
-- This is implemented as a compare-and-swap retry loop. It reads the current
-- value and computes the new one with @f@. It then tries @casIntArray#@ and
-- retries with the freshly observed value if another writer got in first.
-- As a result, @f@ can be evaluated more than once. Callers must guarantee
-- the index is in bounds.
unsafeAtomicModifyIntArray ::
     (Index ix, PrimMonad m) => MArray (PrimState m) P ix Int -> ix -> (Int -> Int) -> m Int
unsafeAtomicModifyIntArray _mpa@(MPArray sz o mba) ix f =
  INDEX_CHECK("unsafeAtomicModifyIntArray", SafeSz . elemsMBA _mpa, atomicModify)
  mba
  (o + toLinearIndex sz ix)
  where
    atomicModify (MutableByteArray mba#) (I# i#) =
      let go s# o# =
            let !(I# n#) = f (I# o#)
             in case casIntArray# mba# i# o# n# s# of
                  (# s'#, o'# #) ->
                    case o# ==# o'# of
                      -- The CAS lost a race, so retry with the value it
                      -- observed. The retry must continue from the state
                      -- token returned by that CAS, not the one before it,
                      -- so the retry stays sequenced after the failed CAS.
                      0# -> go s'# o'#
                      _ -> (# s'#, I# o# #)
       in primitive $ \s# ->
            case atomicReadIntArray# mba# i# s# of
              (# s'#, o# #) -> go s'# o#
    {-# INLINE atomicModify #-}
{-# INLINE unsafeAtomicModifyIntArray #-}
-- | Atomically add to the @Int@ at index @ix@, using @fetchAddIntArray#@,
-- and return the value from before the addition. Callers must guarantee the
-- index is in bounds.
unsafeAtomicAddIntArray ::
     (Index ix, PrimMonad m) => MArray (PrimState m) P ix Int -> ix -> Int -> m Int
unsafeAtomicAddIntArray _marr@(MPArray sz off buf) ix (I# delta#) =
  INDEX_CHECK("unsafeAtomicAddIntArray", SafeSz . elemsMBA _marr, addAt)
    buf
    (off + toLinearIndex sz ix)
  where
    addAt (MutableByteArray buf#) (I# k#) =
      primitive $ \s0# ->
        case fetchAddIntArray# buf# k# delta# s0# of
          (# s1#, prev# #) -> (# s1#, I# prev# #)
    {-# INLINE addAt #-}
{-# INLINE unsafeAtomicAddIntArray #-}
-- | Atomically subtract from the @Int@ at index @ix@, using
-- @fetchSubIntArray#@, and return the value from before the subtraction.
-- Callers must guarantee the index is in bounds.
unsafeAtomicSubIntArray ::
     (Index ix, PrimMonad m) => MArray (PrimState m) P ix Int -> ix -> Int -> m Int
unsafeAtomicSubIntArray _marr@(MPArray sz off buf) ix (I# delta#) =
  INDEX_CHECK("unsafeAtomicSubIntArray", SafeSz . elemsMBA _marr, subAt)
    buf
    (off + toLinearIndex sz ix)
  where
    subAt (MutableByteArray buf#) (I# k#) =
      primitive $ \s0# ->
        case fetchSubIntArray# buf# k# delta# s0# of
          (# s1#, prev# #) -> (# s1#, I# prev# #)
    {-# INLINE subAt #-}
{-# INLINE unsafeAtomicSubIntArray #-}
-- | Atomic bitwise AND on the @Int@ at index @ix@, using
-- @fetchAndIntArray#@, returning the value from before the operation.
-- Callers must guarantee the index is in bounds.
unsafeAtomicAndIntArray ::
     (Index ix, PrimMonad m) => MArray (PrimState m) P ix Int -> ix -> Int -> m Int
unsafeAtomicAndIntArray _marr@(MPArray sz off buf) ix (I# mask#) =
  INDEX_CHECK("unsafeAtomicAndIntArray", SafeSz . elemsMBA _marr, andAt)
    buf
    (off + toLinearIndex sz ix)
  where
    andAt (MutableByteArray buf#) (I# k#) =
      primitive $ \s0# ->
        case fetchAndIntArray# buf# k# mask# s0# of
          (# s1#, prev# #) -> (# s1#, I# prev# #)
    {-# INLINE andAt #-}
{-# INLINE unsafeAtomicAndIntArray #-}
-- | Atomic bitwise NAND on the @Int@ at index @ix@, using
-- @fetchNandIntArray#@, returning the value from before the operation.
-- Callers must guarantee the index is in bounds.
unsafeAtomicNandIntArray ::
     (Index ix, PrimMonad m) => MArray (PrimState m) P ix Int -> ix -> Int -> m Int
unsafeAtomicNandIntArray _marr@(MPArray sz off buf) ix (I# mask#) =
  INDEX_CHECK("unsafeAtomicNandIntArray", SafeSz . elemsMBA _marr, nandAt)
    buf
    (off + toLinearIndex sz ix)
  where
    nandAt (MutableByteArray buf#) (I# k#) =
      primitive $ \s0# ->
        case fetchNandIntArray# buf# k# mask# s0# of
          (# s1#, prev# #) -> (# s1#, I# prev# #)
    {-# INLINE nandAt #-}
{-# INLINE unsafeAtomicNandIntArray #-}
-- | Atomic bitwise OR on the @Int@ at index @ix@, using @fetchOrIntArray#@,
-- returning the value from before the operation. Callers must guarantee the
-- index is in bounds.
unsafeAtomicOrIntArray ::
     (Index ix, PrimMonad m) => MArray (PrimState m) P ix Int -> ix -> Int -> m Int
unsafeAtomicOrIntArray _marr@(MPArray sz off buf) ix (I# mask#) =
  INDEX_CHECK("unsafeAtomicOrIntArray", SafeSz . elemsMBA _marr, orAt)
    buf
    (off + toLinearIndex sz ix)
  where
    orAt (MutableByteArray buf#) (I# k#) =
      primitive $ \s0# ->
        case fetchOrIntArray# buf# k# mask# s0# of
          (# s1#, prev# #) -> (# s1#, I# prev# #)
    {-# INLINE orAt #-}
{-# INLINE unsafeAtomicOrIntArray #-}
-- | Atomic bitwise XOR on the @Int@ at index @ix@, using
-- @fetchXorIntArray#@, returning the value from before the operation.
-- Callers must guarantee the index is in bounds.
unsafeAtomicXorIntArray ::
     (Index ix, PrimMonad m) => MArray (PrimState m) P ix Int -> ix -> Int -> m Int
unsafeAtomicXorIntArray _marr@(MPArray sz off buf) ix (I# mask#) =
  INDEX_CHECK("unsafeAtomicXorIntArray", SafeSz . elemsMBA _marr, xorAt)
    buf
    (off + toLinearIndex sz ix)
  where
    xorAt (MutableByteArray buf#) (I# k#) =
      primitive $ \s0# ->
        case fetchXorIntArray# buf# k# mask# s0# of
          (# s1#, prev# #) -> (# s1#, I# prev# #)
    {-# INLINE xorAt #-}
{-# INLINE unsafeAtomicXorIntArray #-}
-- | Shrink a mutable byte array in place to the given size in bytes, using
-- the @shrinkMutableByteArray#@ primop.
shrinkMutableByteArray :: forall m. (PrimMonad m)
  => MutableByteArray (PrimState m)
  -> Int
  -> m ()
shrinkMutableByteArray (MutableByteArray mba#) (I# newBytes#) =
  primitive_ (\s# -> shrinkMutableByteArray# mba# newBytes# s#)
{-# INLINE shrinkMutableByteArray #-}
-- | Resize a mutable byte array to the given size in bytes. Newer versions
-- of the primitive package provide @resizeMutableByteArray@ directly; for
-- older ones the primop is called by hand.
resizeMutableByteArrayCompat ::
  PrimMonad m => MutableByteArray (PrimState m) -> Int -> m (MutableByteArray (PrimState m))
#if MIN_VERSION_primitive(0,6,4)
resizeMutableByteArrayCompat mba newBytes = resizeMutableByteArray mba newBytes
#else
resizeMutableByteArrayCompat (MutableByteArray mba#) (I# newBytes#) =
  primitive $ \s0# ->
    case resizeMutableByteArray# mba# newBytes# s0# of
      (# s1#, resized# #) -> (# s1#, MutableByteArray resized# #)
#endif
{-# INLINE resizeMutableByteArrayCompat #-}