{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
-- |
-- Arrays of 'Prim' elements stored in GHC byte arrays.
--
-- This module exports the immutable 'PrimArray' and mutable
-- 'MutablePrimArray' wrappers together with the basic operations on them:
-- allocation, resizing, shrinking, writing, indexing, freezing, size query
-- and right/left folds.
module Streamly.Internal.Data.Prim.Array.Types
(
PrimArray(..)
, MutablePrimArray(..)
, newPrimArray
, resizeMutablePrimArray
, shrinkMutablePrimArray
, writePrimArray
, indexPrimArray
, unsafeFreezePrimArray
, sizeofPrimArray
, foldrPrimArray
, foldlPrimArray'
) where
import GHC.Exts
import Data.Primitive.Types
import Data.Primitive.ByteArray (ByteArray(..))
import Control.Monad.Primitive
import qualified Data.Primitive.ByteArray as PB
-- | Immutable array wrapping an unlifted 'ByteArray#'.
--
-- The type parameter @a@ is phantom: it does not appear in the constructor
-- field and only records the element type used by the 'Prim'-based
-- operations in this module.
data PrimArray a = PrimArray ByteArray#
-- | Mutable array wrapping an unlifted 'MutableByteArray#'.
--
-- @s@ is the state-token type of the underlying byte array; @a@ is a phantom
-- element type that does not appear in the constructor field.
data MutablePrimArray s a = MutablePrimArray (MutableByteArray# s)
-- | Identity test for two byte arrays.
--
-- Both arguments are coerced to a lifted type and handed to
-- 'reallyUnsafePtrEquality#'; the array contents are never inspected. A
-- 'False' result therefore says nothing about whether the contents differ.
sameByteArray :: ByteArray# -> ByteArray# -> Bool
sameByteArray lhs rhs =
    -- The coercions are passed directly as arguments (not bound to names) so
    -- that the primop compares the arrays themselves.
    isTrue#
        (reallyUnsafePtrEquality#
            (unsafeCoerce# lhs :: ())
            (unsafeCoerce# rhs :: ()))
-- | Element-wise equality.
--
-- Two arrays that are the very same heap object are reported equal without
-- consulting the element 'Eq' instance. Otherwise arrays with different byte
-- sizes are unequal, and arrays of the same size are compared element by
-- element, starting from the last index and working down to 0.
instance (Eq a, Prim a) => Eq (PrimArray a) where
    {-# INLINE (==) #-}
    lhs@(PrimArray lhs#) == rhs@(PrimArray rhs#)
        | sameByteArray lhs# rhs# = True
        | lhsBytes == rhsBytes = equalDownFrom (quot lhsBytes elemBytes - 1)
        | otherwise = False
      where
        -- Sizes are kept in bytes so the division into an element count is
        -- only performed when the element loop is actually needed.
        lhsBytes = PB.sizeofByteArray (ByteArray lhs#)
        rhsBytes = PB.sizeofByteArray (ByteArray rhs#)
        elemBytes = sizeOf (undefined :: a)

        -- Compare elements at index @ix@ and below; stops at the first
        -- mismatch because '&&' is lazy in its second argument.
        equalDownFrom !ix
            | ix >= 0 =
                indexPrimArray lhs ix == indexPrimArray rhs ix
                    && equalDownFrom (ix - 1)
            | otherwise = True
-- | Lexicographic ordering.
--
-- Arrays that are the same heap object compare as 'EQ' immediately.
-- Otherwise elements are compared from index 0 upwards over the common
-- prefix; the first non-'EQ' element comparison decides the result. If the
-- whole common prefix is equal, the arrays are ordered by their byte sizes.
instance (Ord a, Prim a) => Ord (PrimArray a) where
    {-# INLINE compare #-}
    compare lhs@(PrimArray lhs#) rhs@(PrimArray rhs#)
        | sameByteArray lhs# rhs# = EQ
        | otherwise = walk 0
      where
        lhsBytes = PB.sizeofByteArray (ByteArray lhs#)
        rhsBytes = PB.sizeofByteArray (ByteArray rhs#)
        -- Number of elements present in both arrays.
        common = quot (min lhsBytes rhsBytes) (sizeOf (undefined :: a))

        walk !ix
            | ix >= common = compare lhsBytes rhsBytes
            | otherwise =
                case compare (indexPrimArray lhs ix) (indexPrimArray rhs ix) of
                    -- Only continue while the prefix is still equal.
                    EQ -> walk (ix + 1)
                    decided -> decided
-- | Renders an array as @fromListN \<size\> \<list of elements\>@, wrapped in
-- parentheses when the surrounding precedence is above 10.
instance (Show a, Prim a) => Show (PrimArray a) where
    showsPrec prec arr =
        showParen (prec > 10)
            ( showString "fromListN "
            . shows (sizeofPrimArray arr)
            . showString " "
            . shows (primArrayToList arr)
            )
-- | Convert an array to a list of its elements in index order.
--
-- The list is produced through 'build' over 'foldrPrimArray'.
primArrayToList :: forall a. Prim a => PrimArray a -> [a]
{-# INLINE primArrayToList #-}
primArrayToList arr =
    build (\cons nil -> foldrPrimArray cons nil arr)
-- | Allocate a new mutable array with room for the given number of elements.
--
-- The requested byte size is the element count multiplied by the element
-- size reported by 'sizeOf#'. No validation of the count is performed here.
newPrimArray :: forall m a. (PrimMonad m, Prim a) => Int -> m (MutablePrimArray (PrimState m) a)
{-# INLINE newPrimArray #-}
newPrimArray (I# len#) = primitive allocate
  where
    -- State-passing action that performs the allocation and wraps the result.
    allocate st0# =
        case newByteArray# (len# *# sizeOf# (undefined :: a)) st0# of
            (# st1#, marr# #) -> (# st1#, MutablePrimArray marr# #)
-- | Resize a mutable array to hold the given number of elements.
--
-- Delegates to 'resizeMutableByteArray#' with the size converted from
-- elements to bytes, and returns the array produced by that primop.
--
-- NOTE(review): GHC's primop documentation says the argument array must not
-- be used after resizing; callers should use the returned array — confirm.
resizeMutablePrimArray :: forall m a. (PrimMonad m, Prim a)
    => MutablePrimArray (PrimState m) a
    -> Int
    -> m (MutablePrimArray (PrimState m) a)
{-# INLINE resizeMutablePrimArray #-}
resizeMutablePrimArray (MutablePrimArray marr#) (I# len#) = primitive resize
  where
    resize st0# =
        case resizeMutableByteArray# marr# (len# *# sizeOf# (undefined :: a)) st0# of
            (# st1#, marr'# #) -> (# st1#, MutablePrimArray marr'# #)
-- | Shrink a mutable array in place to the given number of elements.
--
-- Delegates to 'shrinkMutableByteArray#' with the size converted from
-- elements to bytes. The new size is not validated here.
shrinkMutablePrimArray :: forall m a. (PrimMonad m, Prim a)
    => MutablePrimArray (PrimState m) a
    -> Int
    -> m ()
{-# INLINE shrinkMutablePrimArray #-}
shrinkMutablePrimArray (MutablePrimArray marr#) (I# len#) =
    primitive_ $ \st# ->
        shrinkMutableByteArray# marr# (len# *# sizeOf# (undefined :: a)) st#
-- | Write an element at the given element index of a mutable array.
--
-- The index is passed straight to the 'Prim' method 'writeByteArray#'; no
-- bounds check is performed here.
writePrimArray ::
       (Prim a, PrimMonad m)
    => MutablePrimArray (PrimState m) a
    -> Int
    -> a
    -> m ()
{-# INLINE writePrimArray #-}
writePrimArray (MutablePrimArray marr#) (I# ix#) val =
    primitive_ $ \st# -> writeByteArray# marr# ix# val st#
-- | Turn a mutable array into an immutable one via
-- 'unsafeFreezeByteArray#'.
--
-- NOTE(review): as the name indicates this is unsafe; presumably the mutable
-- array must not be modified after freezing — confirm against the primop
-- documentation.
unsafeFreezePrimArray
    :: PrimMonad m => MutablePrimArray (PrimState m) a -> m (PrimArray a)
{-# INLINE unsafeFreezePrimArray #-}
unsafeFreezePrimArray (MutablePrimArray marr#) = primitive freeze
  where
    freeze st0# =
        case unsafeFreezeByteArray# marr# st0# of
            (# st1#, frozen# #) -> (# st1#, PrimArray frozen# #)
-- | Read the element at the given element index of an immutable array.
--
-- The index is passed straight to the 'Prim' method 'indexByteArray#'; no
-- bounds check is performed here.
indexPrimArray :: forall a. Prim a => PrimArray a -> Int -> a
{-# INLINE indexPrimArray #-}
indexPrimArray arr ix =
    case arr of
        PrimArray ba# ->
            case ix of
                I# ix# -> indexByteArray# ba# ix#
-- | Number of elements in the array: the byte size of the underlying byte
-- array divided (with truncation) by the element size from 'sizeOf#'.
sizeofPrimArray :: forall a. Prim a => PrimArray a -> Int
{-# INLINE sizeofPrimArray #-}
sizeofPrimArray (PrimArray ba#) =
    let totalBytes# = sizeofByteArray# ba#
        elemBytes# = sizeOf# (undefined :: a)
     in I# (totalBytes# `quotInt#` elemBytes#)
-- | Lazy right fold over the elements of an array, from index 0 upwards.
--
-- The recursive result is passed to the combining function unevaluated, so
-- a combining function that ignores its second argument stops the traversal.
foldrPrimArray :: forall a b. Prim a => (a -> b -> b) -> b -> PrimArray a -> b
{-# INLINE foldrPrimArray #-}
foldrPrimArray f z arr = walk 0
  where
    -- Element count, computed once up front.
    !len = sizeofPrimArray arr

    walk !ix
        | ix >= len = z
        | otherwise = f (indexPrimArray arr ix) (walk (ix + 1))
-- | Strict left fold over the elements of an array, from index 0 upwards.
--
-- The accumulator is forced to weak head normal form at every step.
foldlPrimArray' :: forall a b. Prim a => (b -> a -> b) -> b -> PrimArray a -> b
{-# INLINE foldlPrimArray' #-}
foldlPrimArray' f z0 arr = step 0 z0
  where
    -- Element count, computed once up front.
    !len = sizeofPrimArray arr

    step !ix !acc
        | ix >= len = acc
        | otherwise = step (ix + 1) (f acc (indexPrimArray arr ix))