{-# LANGUAGE UnboxedTuples #-}

#include "inline.hs"

-- |
-- Module      : Streamly.Internal.Data.Array.Prim.Pinned.Mut.Type
-- Copyright   : (c) 2020 Composewell Technologies
--
-- License     : BSD-3-Clause
-- Maintainer  : streamly@composewell.com
-- Stability   : experimental
-- Portability : GHC
--
module Streamly.Internal.Data.Array.Prim.Pinned.Mut.Type
    (
      Array (..)

    -- * Construction
    , newArray
    , newAlignedArray
    , unsafeWriteIndex

    , spliceTwo
    , unsafeCopy

    , fromListM
    , fromListNM
    , fromStreamDN
    , fromStreamD

    -- * Streams of arrays
    , fromStreamDArraysOf

    , packArraysChunksOf
    , lpackArraysChunksOf

#if !defined(mingw32_HOST_OS)
--    , groupIOVecsOf
#endif

    -- * Elimination
    , unsafeReadIndex
    , length
    , byteLength

    , writeN
    , ArrayUnsafe(..)
    , writeNUnsafe
    , writeNAligned
    , write

    -- * Utilities
    , resizeArray
    , shrinkArray

    , touchArray
    , withArrayAsPtr
    )
where

import GHC.IO (IO(..))

#include "Streamly/Internal/Data/Array/Prim/MutTypesInclude.hs"

-------------------------------------------------------------------------------
-- Allocation (Pinned)
-------------------------------------------------------------------------------

-- XXX we can use a single newArray routine which accepts an allocation
-- function which could be newByteArray#, newPinnedByteArray# or
-- newAlignedPinnedByteArray#. That function can go in the common include file.
--
-- | Allocate a pinned array that can hold the given number of elements of
-- type @a@. The memory of the array is uninitialized.
--
-- The size in bytes is the element count multiplied by @sizeOf# (undefined ::
-- a)@. A negative element count is rejected with an 'error'. Without the
-- guard it would be handed directly to 'newPinnedByteArray#' as a byte count,
-- producing an array with a bogus size.
--
-- NOTE(review): the byte-count multiplication is not checked for overflow.
--
-- Note that this is an internal routine; a reference to the array must not be
-- given out until the array has been written to and frozen.
{-# INLINE newArray #-}
newArray ::
       forall m a. (MonadIO m, Prim a)
    => Int
    -> m (Array a)
newArray n@(I# n#)
    | n < 0 = error $ "newArray: negative element count: " ++ show n
    | otherwise = liftIO $ do
        let bytes = n# *# sizeOf# (undefined :: a)
        primitive $ \s# ->
            case newPinnedByteArray# bytes s# of
                (# s1#, arr# #) -> (# s1#, Array arr# #)

-- XXX Change order of args?
-- | Allocate a new pinned array that can hold the given number of elements.
-- Its memory is aligned to the given alignment, in bytes, as accepted by
-- 'newAlignedPinnedByteArray#'. The memory of the array is uninitialized.
--
-- A negative element count is rejected with an 'error' rather than being
-- passed on to the allocation primop as a byte count.
--
-- NOTE(review): GHC documents that the alignment must be a power of two; it
-- is not validated here.
{-# INLINE newAlignedArray #-}
newAlignedArray ::
       forall m a. (MonadIO m, Prim a)
    => Int -- size
    -> Int -- Alignment
    -> m (Array a)
newAlignedArray n@(I# n#) (I# align#)
    | n < 0 = error $ "newAlignedArray: negative element count: " ++ show n
    | otherwise = liftIO $ do
        let bytes = n# *# sizeOf# (undefined :: a)
        primitive $ \s# ->
            case newAlignedPinnedByteArray# bytes align# s# of
                (# s1#, arr# #) -> (# s1#, Array arr# #)

-- | Resize a pinned mutable array so that it holds the given number of
-- elements (an element count, not a byte count).
--
-- * Same size: the array itself is returned.
-- * Smaller: the array is shrunk in place with 'shrinkArray', and the same
--   array is returned.
-- * Larger: the array is never grown in place. A new pinned array is
--   allocated with 'newArray', the existing elements are copied to its start
--   with 'unsafeCopy', and the new array is returned. Elements past the old
--   length are uninitialized, and the original array is only read.
{-# INLINE resizeArray #-}
resizeArray ::
       (MonadIO m, Prim a)
    => Array a
    -> Int -- ^ new size
    -> m (Array a)
resizeArray src newLen = do
    oldLen <- length src
    case compare newLen oldLen of
        EQ -> pure src
        LT -> src <$ shrinkArray src newLen
        GT -> do
            dst <- newArray newLen
            unsafeCopy dst 0 src 0 oldLen
            pure dst

-------------------------------------------------------------------------------
-- Aligned Construction
-------------------------------------------------------------------------------

-- XXX we can also factor out common code in writeN and writeNAligned in the
-- same way as suggested for newArray (a single allocation routine that is
-- parameterized by the allocation primop).
--
-- | @writeNAligned alignment n@ is a fold that writes at most @n@ input
-- elements into a newly allocated pinned array whose memory is aligned to
-- @alignment@ (see 'newAlignedArray').
--
-- The fold finishes as soon as the @n@-th element has been written, so it
-- never consumes more than @n@ input elements. If the input ends earlier, the
-- array is shrunk to the number of elements actually written. A non-positive
-- @n@ yields an empty array without consuming any input.
{-# INLINE_NORMAL writeNAligned #-}
writeNAligned ::
       (MonadIO m, Prim a)
    => Int
    -> Int
    -> Fold m a (Array a)
writeNAligned align limit = Fold step initial extract

    where

    -- Nothing can be written into a zero-capacity array, so finish before
    -- consuming any input (and never ask for a negative-sized allocation).
    initial
        | limit <= 0 = FL.Done <$> newAlignedArray 0 align
        | otherwise = do
            marr <- newAlignedArray limit align
            return $ FL.Partial $ Tuple' marr 0

    -- Input ended before the array filled up: drop the unwritten tail so the
    -- array length equals the number of elements written.
    extract (Tuple' marr len) = shrinkArray marr len >> return marr

    -- The array was allocated with exactly 'limit' elements. Index i is
    -- always < limit here. Once the limit-th element is written the array is
    -- full, so finish immediately. Waiting for one more element would consume
    -- and discard it.
    step (Tuple' marr i) x = do
        unsafeWriteIndex marr i x
        let i' = i + 1
        if i' == limit
        then return $ FL.Done marr
        else return $ FL.Partial $ Tuple' marr i'

-------------------------------------------------------------------------------
-- Mutation with pointers
-------------------------------------------------------------------------------

-- XXX This section can probably go in a common include file for pinned arrays.

-- Change name later.
-- | The address of the first byte of the array's payload, as a typed pointer.
--
-- The mutable array is coerced to a 'ByteArray#' only so that
-- 'byteArrayContents#' can read its address. The address is only stable for
-- pinned arrays, such as those allocated by 'newArray' and 'newAlignedArray'.
-- The returned pointer does not keep the array alive. Callers must
-- 'touchArray' after the last use of the pointer, as 'withArrayAsPtr' does.
{-# INLINE toPtr #-}
toPtr :: Array a -> Ptr a
toPtr arr = case arr of
    Array marr# -> Ptr (byteArrayContents# (unsafeCoerce# marr#))

-- | Keep the array reachable up to this point of an 'IO' computation.
--
-- Applies the 'touch#' primop to the array value and returns @()@. Run it
-- after the last use of a raw pointer obtained from the array (see 'toPtr').
{-# INLINE touchArray #-}
touchArray :: Array a -> IO ()
touchArray arr = IO touchIt
  where
    touchIt s0 = case touch# arr s0 of
        s1 -> (# s1, () #)

-- | Run an 'IO' action with a raw pointer to the start of the array's payload.
-- The array is touched once the action completes, so that it stays alive
-- while the action uses the pointer. The pointer must not escape the action.
--
-- NOTE(review): if the action throws, the touch is skipped. GHC may also
-- discard a 'touch#' that follows a computation it knows to diverge. Newer
-- GHCs provide @keepAlive#@ to address this.
{-# INLINE withArrayAsPtr #-}
withArrayAsPtr :: Array a -> (Ptr a -> IO b) -> IO b
withArrayAsPtr arr action = action (toPtr arr) <* touchArray arr