{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UnboxedTuples #-}

module Foreign.C.String.Managed
  ( ManagedCString (..)
  , terminated
  , terminatedU
  , unterminated
  , fromBytes
  , fromLatinString
  , fromShortText
  , pinnedFromBytes
  , pin
  , touch
  , contents
  ) where

import Control.Monad.ST (ST)
import Control.Monad.ST.Run (runByteArrayST)
import Data.Bytes.Types (Bytes (Bytes))
import Data.Char (ord)
import Data.Primitive (ByteArray (..), MutableByteArray)
import Data.Text.Short (ShortText)
import Data.Word (Word8)
import Foreign.C.String (CString)
import Foreign.Ptr (castPtr)
import GHC.Exts (ByteArray#, Char (C#), Int (I#), chr#, touch#)
import GHC.IO (IO (IO))

import qualified Data.Bytes as Bytes
import qualified Data.Bytes.Text.Utf8 as Utf8
import qualified Data.Primitive as PM
import qualified GHC.Exts as Exts

{- | A @NUL@-terminated byte sequence. The wrapped 'ByteArray' is used in
full (it is never a slice) and its final byte is @NUL@. Equality is the
equality of the underlying 'ByteArray'.
-}
newtype ManagedCString = ManagedCString ByteArray
  deriving newtype (Eq)

{- | Concatenation. All bytes of the left operand except its last one are
copied, followed by every byte of the right operand (its final byte
included), so the result is one byte shorter than the two arrays combined.
-}
instance Semigroup ManagedCString where
  ManagedCString lhs <> ManagedCString rhs =
    ManagedCString (runByteArrayST build)
   where
    -- Everything in the left operand except its final byte.
    lhsPayload = PM.sizeofByteArray lhs - 1
    -- The right operand is copied whole.
    rhsSize = PM.sizeofByteArray rhs

    build :: forall s. ST s ByteArray
    build = do
      buf <- PM.newByteArray (lhsPayload + rhsSize)
      PM.copyByteArray buf 0 lhs 0 lhsPayload
      PM.copyByteArray buf lhsPayload rhs 0 rhsSize
      PM.unsafeFreezeByteArray buf

-- | The identity element is a one-byte array holding only the @NUL@ byte.
instance Monoid ManagedCString where
  mempty =
    ManagedCString
      ( runByteArrayST
          ( PM.newByteArray 1 >>= \nul ->
              PM.writeByteArray nul 0 (0 :: Word8)
                >> PM.unsafeFreezeByteArray nul
          )
      )

-- | String literals are converted with 'fromLatinString'.
instance Exts.IsString ManagedCString where
  fromString literal = fromLatinString literal

{- | Renders the bytes as characters, without quotes and ignoring the
precedence argument. Bytes equal to @0@ are omitted, bytes below @32@ or
above @126@ are shown as @?@, and every other byte is shown as the character
with that code.
-}
instance Show ManagedCString where
  showsPrec _ (ManagedCString bytes) rest =
    PM.foldrByteArray prepend rest bytes
   where
    -- Prepend the rendering of one byte to the already-rendered suffix.
    prepend :: Word8 -> String -> String
    prepend byte acc
      | byte == 0 = acc
      | byte < 32 = '?' : acc
      | byte > 126 = '?' : acc
      | otherwise = case fromIntegral @Word8 @Int byte of
          I# code -> C# (chr# code) : acc

-- | Unwrap to the underlying 'ByteArray', trailing byte included.
terminatedU :: ManagedCString -> ByteArray
terminatedU cstr = case cstr of
  ManagedCString arr -> arr

{- | The whole underlying array as 'Bytes' (via 'Bytes.fromByteArray'),
trailing byte included.
-}
terminated :: ManagedCString -> Bytes
terminated (ManagedCString arr) = wholeArray
 where
  wholeArray = Bytes.fromByteArray arr

{- | A 'Bytes' slice over the same array that starts at offset @0@ and
excludes the final byte.
-}
unterminated :: ManagedCString -> Bytes
unterminated (ManagedCString arr) = Bytes arr 0 payloadLen
 where
  -- Size of the array minus its last byte.
  payloadLen = PM.sizeofByteArray arr - 1

{- | Convert a 'ShortText' by passing the result of 'Utf8.fromShortText' to
'fromBytes'. The argument is forced before conversion.
-}
fromShortText :: ShortText -> ManagedCString
fromShortText !txt = fromBytes encoded
 where
  encoded = Utf8.fromShortText txt

{- | Copy the slice into a fresh array that is one byte longer than the
slice, and set that extra final byte to @NUL@.
-}
fromBytes :: Bytes -> ManagedCString
fromBytes (Bytes src start count) = ManagedCString (runByteArrayST build)
 where
  build :: forall s. ST s ByteArray
  build = do
    buf <- PM.newByteArray (count + 1)
    PM.copyByteArray buf 0 src start count
    -- The terminator goes right after the copied payload.
    PM.writeByteArray buf count (0 :: Word8)
    PM.unsafeFreezeByteArray buf

{- | Like 'fromBytes', but the destination is allocated with
'PM.newPinnedByteArray': the slice is copied into pinned memory and a @NUL@
byte is written after it.
-}
pinnedFromBytes :: Bytes -> ManagedCString
pinnedFromBytes (Bytes src start count) =
  ManagedCString
    (runByteArrayST (PM.newPinnedByteArray (count + 1) >>= populate))
 where
  -- Fill the freshly allocated buffer and freeze it.
  populate :: forall s. MutableByteArray s -> ST s ByteArray
  populate buf = do
    PM.copyByteArray buf 0 src start count
    PM.writeByteArray buf count (0 :: Word8)
    PM.unsafeFreezeByteArray buf

{- | Ensure the string is backed by a pinned array. If
'PM.isByteArrayPinned' already holds, the argument is returned unchanged;
otherwise all of its bytes are copied into a newly allocated pinned array of
the same size.
-}
pin :: ManagedCString -> ManagedCString
pin original@(ManagedCString arr)
  | PM.isByteArrayPinned arr = original
  | otherwise = ManagedCString (runByteArrayST relocate)
 where
  size = PM.sizeofByteArray arr

  relocate :: forall s. ST s ByteArray
  relocate = do
    buf <- PM.newPinnedByteArray size
    PM.copyByteArray buf 0 arr 0 size
    PM.unsafeFreezeByteArray buf

{- | Apply 'touchByteArray#' to the array underlying the string.
NOTE(review): presumably used to keep the array alive while a pointer
obtained from 'contents' is still in use -- confirm against callers.
-}
touch :: ManagedCString -> IO ()
touch (ManagedCString (ByteArray arr#)) = touchByteArray# arr#

-- | Wrap the 'touch#' primop, applied to an unlifted byte array, as an 'IO'
-- action that returns unit.
touchByteArray# :: ByteArray# -> IO ()
touchByteArray# arr# =
  IO
    ( \world ->
        case touch# arr# world of
          world' -> (# world', () #)
    )

{- | Build a 'ManagedCString' from a 'String' whose characters are all
representable in ISO-8859-1; each character is stored as its ISO-8859-1
byte. A character with a codepoint above @U+00FF@ is replaced by an
unspecified byte.

The output buffer starts with room for 63 payload bytes plus the terminator
and is doubled whenever it fills up. At the end it is resized to exactly the
payload length plus one.
-}
fromLatinString :: String -> ManagedCString
{-# NOINLINE fromLatinString #-}
fromLatinString str = ManagedCString (runByteArrayST start)
 where
  -- Initial payload capacity. Buffers are always allocated one byte larger
  -- than their payload capacity so that the terminator fits.
  initialCapacity :: Int
  initialCapacity = 63

  start :: forall s. ST s ByteArray
  start = do
    buf <- PM.newByteArray (initialCapacity + 1)
    fill str buf 0 initialCapacity

  -- @fill chars buf used capacity@ writes the remaining characters into
  -- @buf@ starting at index @used@; @capacity@ is the payload capacity of
  -- @buf@.
  fill :: forall s. String -> MutableByteArray s -> Int -> Int -> ST s ByteArray
  fill chars !buf !used !capacity = case chars of
    [] -> do
      -- Terminate, then trim the buffer to the bytes actually used.
      PM.writeByteArray buf used (0 :: Word8)
      trimmed <- PM.resizeMutableByteArray buf (used + 1)
      PM.unsafeFreezeByteArray trimmed
    c : rest
      | used < capacity -> do
          PM.writeByteArray buf used (latin1 c)
          fill rest buf (used + 1) capacity
      | otherwise -> do
          -- Out of room: move what has been written into a buffer with
          -- twice the payload capacity and continue there.
          let doubled = capacity * 2
          bigger <- PM.newByteArray (doubled + 1)
          PM.copyMutableByteArray bigger 0 buf 0 used
          PM.writeByteArray bigger used (latin1 c)
          fill rest bigger (used + 1) doubled

  -- Codepoint of the character converted to a byte with 'fromIntegral'.
  latin1 :: Char -> Word8
  latin1 ch = fromIntegral @Int @Word8 (ord ch)

{- | Pointer to the payload of the managed C string, obtained from
'PM.byteArrayContents' and cast to 'CString'. The behavior is undefined if
the argument is not pinned.
-}
contents :: ManagedCString -> CString
contents (ManagedCString arr) = castPtr payload
 where
  payload = PM.byteArrayContents arr