{-# LANGUAGE CPP #-} -- for a bytestring version gate >:(
{-# LANGUAGE UnboxedTuples #-}

-- May as well export everything: the interface is highly unsafe anyway.
module Bytezap.Poke where

import GHC.Exts
import Raehik.Compat.GHC.Exts.GHC908MemcpyPrimops

import GHC.Word ( Word8(W8#) )

import Data.ByteString qualified as BS
import Data.ByteString.Internal qualified as BS

import Control.Monad ( void )

import Raehik.Compat.Data.Primitive.Types

import GHC.ForeignPtr

import Control.Monad.Primitive

import Bytezap.Struct qualified as Struct

-- | The unboxed poke function.
--
--   Arguments are a base address, an offset from that base, and a state token.
--   Implementations write their bytes starting at @base + offset@. They return
--   the new state token together with the offset just past the bytes written.
type Poke# s = Addr# -> Int# -> State# s -> (# State# s, Int# #)

-- | Poke newtype wrapper.
newtype Poke s = Poke { unPoke :: Poke# s }

-- | Sequence two 'Poke's left-to-right.
--
--   The right poke starts writing at the offset the left poke finished at.
--   The combined poke returns the offset the right poke finished at.
instance Semigroup (Poke s) where
    Poke l <> Poke r = Poke $ \base# os0# s0 ->
        case l base# os0# s0 of
          (# s1, os1# #) -> r base# os1# s1

-- | The empty poke writes nothing and returns the offset it was given.
instance Monoid (Poke s) where
    mempty = Poke $ \_base# os# s -> (# s, os# #)

-- | Execute a 'Poke' at a fresh 'BS.ByteString' of the given length.
--
--   The final offset reported by the poke is discarded. Nothing checks that
--   the poke wrote exactly @len@ bytes. Writing more than @len@ bytes
--   overruns the buffer, so the caller must get the length right.
unsafeRunPokeBS :: Int -> Poke RealWorld -> BS.ByteString
unsafeRunPokeBS len p = BS.unsafeCreate len (void . unsafeRunPoke p)

-- | Execute a 'Poke' at a fresh 'BS.ByteString' of the given maximum length.
--
--   The offset the poke finishes at (i.e. the number of bytes written) is
--   used as the length of the resulting 'BS.ByteString'. This does not
--   reallocate if the final size is less than estimated. The poke must not
--   write more than @len@ bytes; this is not checked.
unsafeRunPokeBSUptoN :: Int -> Poke RealWorld -> BS.ByteString
unsafeRunPokeBSUptoN len = BS.unsafeCreateUptoN len . unsafeRunPoke

-- | Execute a 'Poke' at a pointer. Returns the number of bytes written.
--
-- The pointer must be a mutable buffer with enough space to hold the poke.
-- Absolutely none of this is checked. Use with caution. Sensible uses:
--
-- * implementing pokes to ByteStrings and the like
-- * executing known-length (!!) pokes to known-length (!!) buffers e.g.
--   together with allocaBytes
unsafeRunPoke :: MonadPrim s m => Poke s -> Ptr Word8 -> m Int
unsafeRunPoke (Poke p) (Ptr base#) = primitive $ \s0 ->
    -- The poke starts at offset 0 from the pointer. Its final offset is
    -- therefore exactly the number of bytes it wrote.
    case p base# 0# s0 of
      (# s1, os# #) -> (# s1, I# os# #)

-- | Poke a type via its 'Prim'' instance.
--
--   Writes the value at the current offset, then advances the offset by
--   the type's size as reported by 'sizeOf#'.
prim :: forall a s. Prim' a => a -> Poke s
prim a = Poke $ \base# os# s0 ->
    case writeWord8OffAddrAs# base# os# a s0 of
      -- 'sizeOf#' is used only for the type of its argument. The
      -- 'undefined' is a type proxy and is never evaluated.
      s1 -> (# s1, os# +# sizeOf# (undefined :: a) #)

-- we reimplement withForeignPtr because it's too high level.
-- keepAlive# has the wrong type before GHC 9.10, but it doesn't matter here
-- because copyAddrToAddrNonOverlapping# forces RealWorld.

-- | Poke the bytes of a 'BS.ByteString' at the current offset. Advances the
--   offset by the bytestring's length.
byteString :: BS.ByteString -> Poke RealWorld
byteString (BS.BS (ForeignPtr p# r) (I# len#)) = Poke $ \base# os# s0 ->
    -- Keep the bytestring's backing storage alive for the duration of the copy.
    keepAlive# r s0 $ \s1 ->
        case copyAddrToAddrNonOverlapping# p# (base# `plusAddr#` os#) len# s1 of
          s2 -> (# s2, os# +# len# #)

-- | Poke a slice of a 'ByteArray#' at the current offset.
--
--   The arguments are the source array, the byte offset into the array, and
--   the number of bytes to copy. Advances the offset by that byte count.
byteArray# :: ByteArray# -> Int# -> Int# -> Poke s
byteArray# ba# baos# balen# = Poke $ \base# os# s0 ->
    case copyByteArrayToAddr# ba# baos# (base# `plusAddr#` os#) balen# s0 of
      s1 -> (# s1, os# +# balen# #)

-- | Write @len@ copies of the given byte at the current offset (essentially
--   @memset@). Advances the offset by @len@.
replicateByte :: Int -> Word8 -> Poke RealWorld
replicateByte (I# len#) (W8# byte#) = Poke $ \base# os# s0 ->
    case setAddrRange# (base# `plusAddr#` os#) len# byteAsInt# s0 of
      s1 -> (# s1, os# +# len# #)
  where
    -- setAddrRange# takes the fill value as an Int#.
    byteAsInt# :: Int#
    byteAsInt# = word2Int# (word8ToWord# byte#)

-- | Use a struct poke as a regular poke.
--
-- To do this, we must associate a constant byte length with an existing poker.
-- Note that pokers don't expose the type of the data they are serializing,
-- so this is a very clumsy operation by itself. You should only be using this
-- when you have such types in scope, and the constant length should be obtained
-- in a sensible manner (e.g. 'Bytezap.Struct.Generic.KnownSizeOf' for generic
-- struct pokers, or your own constant size class if you're doing funky stuff).
--
-- The given length is trusted as-is: the returned offset is the starting
-- offset plus that length, whatever the struct poke actually wrote.
fromStructPoke :: Int -> Struct.Poke s -> Poke s
fromStructPoke (I# len#) (Struct.Poke p) = Poke $ \base# os# s ->
    (# p base# os# s, os# +# len# #)

-- | Use a regular poke as a struct poke by throwing away the returned offset.
toStructPoke :: Poke s -> Struct.Poke s
toStructPoke (Poke p) = Struct.Poke $ \base# os0# s0 ->
    case p base# os0# s0 of
      (# s1, _os1# #) -> s1