{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}

-- |
-- Module      : Streamly.Internal.Data.Prim.Array.Types
-- Copyright   : (c) Roman Leshchinskiy 2009-2012
-- License     : BSD-style
--
-- Maintainer  : streamly@composewell.com
-- Portability : non-portable
--
-- Arrays of unboxed primitive types. The functions provided by this module
-- match the behavior of those provided by @Data.Primitive.ByteArray@, and
-- the underlying types and primops that back them are the same.
-- However, the type constructors 'PrimArray' and 'MutablePrimArray' take one
-- more argument than their respective counterparts 'ByteArray' and
-- 'MutableByteArray'. This argument designates the type of element in the
-- array. Consequently, all functions in this module accept lengths and
-- indices in terms of elements, not bytes.
--
-- @since 0.6.4.0
module Streamly.Internal.Data.Prim.Array.Types
  ( -- * Types
    PrimArray(..)
  , MutablePrimArray(..)
    -- * Allocation
  , newPrimArray
  , resizeMutablePrimArray
  , shrinkMutablePrimArray
    -- * Element Access
  , writePrimArray
  , indexPrimArray
    -- * Freezing and Thawing
  , unsafeFreezePrimArray
    -- * Information
  , sizeofPrimArray
    -- * Folding
  , foldrPrimArray
  , foldlPrimArray'
  ) where

import GHC.Exts

import Data.Primitive.Types
import Data.Primitive.ByteArray (ByteArray(..))
import Control.Monad.Primitive
import qualified Data.Primitive.ByteArray as PB

-- | Arrays of unboxed elements. This accepts types like 'Double', 'Char',
-- 'Int', and 'Word', as well as their fixed-length variants ('Word8',
-- 'Word16', etc.). Since the elements are unboxed, a 'PrimArray' is strict
-- in its elements. This differs from the behavior of 'Array', which is lazy
-- in its elements.
--
-- The element type @a@ is a phantom parameter. The representation is a
-- plain 'ByteArray#', and the functions in this module use the 'Prim'
-- instance of @a@ to translate element counts and indices into bytes.
data PrimArray a = PrimArray ByteArray#

-- | Mutable primitive arrays associated with a primitive state token.
-- These can be written to and read from in a monadic context that supports
-- sequencing such as 'IO' or 'ST'. Typically, a mutable primitive array will
-- be built and then converted to an immutable primitive array using
-- 'unsafeFreezePrimArray'. However, it is also acceptable to simply discard
-- a mutable primitive array since it lives in managed memory and will be
-- garbage collected when no longer referenced.
--
-- @s@ is the state-token type of the underlying 'MutableByteArray#'.
-- @a@ is a phantom element type, used only to select the 'Prim' instance
-- that converts element counts and indices into bytes.
data MutablePrimArray s a = MutablePrimArray (MutableByteArray# s)

-- | Pointer equality on byte arrays: 'True' exactly when both arguments are
-- the same heap object. Two distinct arrays with identical contents compare
-- 'False'. This is only a fast-path check; it never inspects the contents.
--
-- Both arguments are reinterpreted as 'MutableByteArray#' so that the
-- 'sameMutableByteArray#' comparison can be used. That is a coercion between
-- two unlifted types with the same runtime representation, which is a use
-- 'unsafeCoerce#' supports. Coercing the unlifted 'ByteArray#' to a lifted
-- type such as @()@ (the previous approach) is not.
sameByteArray :: ByteArray# -> ByteArray# -> Bool
sameByteArray ba1# ba2# =
  isTrue#
    (sameMutableByteArray#
       (unsafeCoerce# ba1# :: MutableByteArray# RealWorld)
       (unsafeCoerce# ba2# :: MutableByteArray# RealWorld))

-- | Element-wise equality. Two arrays are equal when they hold the same
-- number of elements and every pair of corresponding elements is '=='.
-- Arrays backed by the very same buffer are reported equal without
-- inspecting any element.
--
-- @since 0.6.4.0
instance (Eq a, Prim a) => Eq (PrimArray a) where
  lhs@(PrimArray lhs#) == rhs@(PrimArray rhs#)
    | sameByteArray lhs# rhs# = True
    | bytesL /= bytesR = False
    | otherwise = allEqualFrom (bytesL `quot` sizeOf (undefined :: a) - 1)
    where
      -- Sizes are compared in bytes, so the division into an element count
      -- is only done once both arrays are known to be the same length.
      bytesL, bytesR :: Int
      bytesL = PB.sizeofByteArray (ByteArray lhs#)
      bytesR = PB.sizeofByteArray (ByteArray rhs#)

      -- Compare elements from index @ix@ down to 0, stopping at the first
      -- mismatch.
      allEqualFrom :: Int -> Bool
      allEqualFrom !ix
        | ix < 0 = True
        | otherwise =
            indexPrimArray lhs ix == indexPrimArray rhs ix
              && allEqualFrom (ix - 1)
  {-# INLINE (==) #-}

-- | Lexicographic ordering. Elements are compared pairwise starting at
-- index 0, and the first difference decides the result. When one array is
-- a prefix of the other, the shorter one is smaller. Subject to change
-- between major versions.
--
--   @since 0.6.4.0
instance (Ord a, Prim a) => Ord (PrimArray a) where
  compare lhs@(PrimArray lhs#) rhs@(PrimArray rhs#)
    | sameByteArray lhs# rhs# = EQ
    | otherwise = scanFrom 0
    where
      bytesL, bytesR :: Int
      bytesL = PB.sizeofByteArray (ByteArray lhs#)
      bytesR = PB.sizeofByteArray (ByteArray rhs#)

      -- Number of elements present in both arrays.
      common :: Int
      common = min bytesL bytesR `quot` sizeOf (undefined :: a)

      -- Walk the shared prefix; fall back to comparing lengths at its end.
      scanFrom :: Int -> Ordering
      scanFrom !ix
        | ix >= common = compare bytesL bytesR
        | otherwise =
            case compare (indexPrimArray lhs ix) (indexPrimArray rhs ix) of
              EQ -> scanFrom (ix + 1)
              decided -> decided
  {-# INLINE compare #-}

-- | Renders the array as a @fromListN@ expression followed by the element
-- count and the element list. The output is parenthesised when the
-- precedence context exceeds function application.
--
-- @since 0.6.4.0
instance (Show a, Prim a) => Show (PrimArray a) where
  showsPrec d arr = showParen (d > 10) rendered
    where
      rendered =
        showString "fromListN "
          . shows (sizeofPrimArray arr)
          . showString " "
          . shows (primArrayToList arr)

-- | The elements of the array as a list, in index order. The list is
-- produced with 'build' over 'foldrPrimArray', so it is generated lazily
-- and can participate in foldr/build list fusion.
{-# INLINE primArrayToList #-}
primArrayToList :: forall a. Prim a => PrimArray a -> [a]
primArrayToList arr = build producer
  where
    producer :: forall b. (a -> b -> b) -> b -> b
    producer cons nil = foldrPrimArray cons nil arr

-- | Allocate a mutable primitive array with room for the given number of
-- elements. The byte size requested is the element count times the size of
-- one element. The memory is left uninitialized.
newPrimArray :: forall m a. (PrimMonad m, Prim a) => Int -> m (MutablePrimArray (PrimState m) a)
{-# INLINE newPrimArray #-}
newPrimArray (I# count#) = primitive allocate
  where
    allocate s0 =
      case newByteArray# (count# *# sizeOf# (undefined :: a)) s0 of
        (# s1, mba# #) -> (# s1, MutablePrimArray mba# #)

-- | Resize a mutable primitive array. The new size is given in elements and
-- is converted to bytes using the element size.
--
-- This will either resize the array in-place or, if not possible, allocate the
-- contents into a new, unpinned array and copy the original array\'s contents.
--
-- To avoid undefined behaviour, the original 'MutablePrimArray' shall not be
-- accessed anymore after a 'resizeMutablePrimArray' has been performed.
-- Moreover, no reference to the old one should be kept in order to allow
-- garbage collection of the original 'MutablePrimArray' in case a new
-- 'MutablePrimArray' had to be allocated.
resizeMutablePrimArray :: forall m a. (PrimMonad m, Prim a)
  => MutablePrimArray (PrimState m) a
  -> Int -- ^ new size
  -> m (MutablePrimArray (PrimState m) a)
{-# INLINE resizeMutablePrimArray #-}
resizeMutablePrimArray (MutablePrimArray old#) (I# count#) = primitive resize
  where
    resize s0 =
      case resizeMutableByteArray# old# (count# *# sizeOf# (undefined :: a)) s0 of
        (# s1, new# #) -> (# s1, MutablePrimArray new# #)

-- Although it is possible to shim resizeMutableByteArray for old GHCs, this
-- is not the case with shrinkMutablePrimArray.

-- | Shrink a mutable primitive array in place. The new size is given in
-- elements and must not exceed the current size (see
-- 'shrinkMutableByteArray#'). Requires GHC 7.10 or newer, which provides
-- that primop.
shrinkMutablePrimArray :: forall m a. (PrimMonad m, Prim a)
  => MutablePrimArray (PrimState m) a
  -> Int -- ^ new size
  -> m ()
{-# INLINE shrinkMutablePrimArray #-}
shrinkMutablePrimArray (MutablePrimArray mba#) (I# count#) =
  primitive_ $ \s ->
    shrinkMutableByteArray# mba# (count# *# sizeOf# (undefined :: a)) s

-- | Store an element at the given element index. The index is not
-- bounds-checked.
writePrimArray ::
     (Prim a, PrimMonad m)
  => MutablePrimArray (PrimState m) a -- ^ array
  -> Int -- ^ index
  -> a -- ^ element
  -> m ()
{-# INLINE writePrimArray #-}
writePrimArray (MutablePrimArray mba#) (I# ix#) val =
  primitive_ $ \s -> writeByteArray# mba# ix# val s

-- | Convert a mutable primitive array to an immutable one without copying.
-- Both views share the same memory, so the mutable array should not be
-- modified after the conversion.
unsafeFreezePrimArray
  :: PrimMonad m => MutablePrimArray (PrimState m) a -> m (PrimArray a)
{-# INLINE unsafeFreezePrimArray #-}
unsafeFreezePrimArray (MutablePrimArray mba#) = primitive freeze
  where
    freeze s0 =
      case unsafeFreezeByteArray# mba# s0 of
        (# s1, ba# #) -> (# s1, PrimArray ba# #)

-- | Read the element at the given element index. The index is not
-- bounds-checked.
indexPrimArray :: forall a. Prim a => PrimArray a -> Int -> a
{-# INLINE indexPrimArray #-}
indexPrimArray (PrimArray ba#) ix =
  case ix of
    I# ix# -> indexByteArray# ba# ix#

-- | Number of elements in the array: the byte size of the underlying buffer
-- divided (truncating) by the size of one element.
sizeofPrimArray :: forall a. Prim a => PrimArray a -> Int
{-# INLINE sizeofPrimArray #-}
sizeofPrimArray (PrimArray ba#) =
  I# (sizeofByteArray# ba# `quotInt#` sizeOf# (undefined :: a))

-- | Lazy right-associated fold over the elements of a 'PrimArray', visiting
-- indices from 0 upwards. The rest of the fold is passed to @f@ unevaluated,
-- so @f@ may stop early.
{-# INLINE foldrPrimArray #-}
foldrPrimArray :: forall a b. Prim a => (a -> b -> b) -> b -> PrimArray a -> b
foldrPrimArray f z arr = step 0
  where
    !len = sizeofPrimArray arr

    step :: Int -> b
    step !ix
      | ix >= len = z
      | otherwise = f (indexPrimArray arr ix) (step (ix + 1))

-- | Strict left-associated fold over the elements of a 'PrimArray', visiting
-- indices from 0 upwards. The accumulator is forced to weak head normal form
-- at every step.
{-# INLINE foldlPrimArray' #-}
foldlPrimArray' :: forall a b. Prim a => (b -> a -> b) -> b -> PrimArray a -> b
foldlPrimArray' f z0 arr = accumulate 0 z0
  where
    !len = sizeofPrimArray arr

    accumulate :: Int -> b -> b
    accumulate !ix !acc
      | ix >= len = acc
      | otherwise = accumulate (ix + 1) (f acc (indexPrimArray arr ix))