{-|
Module      : Z.Foreign.CPtr
Description : Lightweight foreign pointer
Copyright   : (c) Dong Han, 2020
License     : BSD
Maintainer  : winterland1989@gmail.com
Stability   : experimental
Portability : non-portable

This module provides a lightweight foreign pointer type, 'CPtr', which supports only C initializers and finalizers.
-}

module Z.Foreign.CPtr (
  -- * CPtr type
    CPtr, newCPtr', newCPtrUnsafe, newCPtr, withCPtr, withCPtrsUnsafe, withCPtrs
  -- * Ptr type
  , Ptr
  , nullPtr
  , FunPtr
  ) where

import Control.Monad
import Control.Monad.Primitive
import Control.Exception                    (mask_)
import Data.Primitive.PrimArray
import qualified Z.Data.Text                as T
import GHC.Ptr
import GHC.Exts
import Z.Data.Array
import Z.Foreign

-- | Lightweight foreign pointers.
--
-- A 'CPtr' wraps a one-element 'PrimArray' that holds the raw pointer (it is
-- always read back with @indexPrimArray _ 0@). The constructors in this module
-- attach a C finalizer to a weak pointer keyed on that array's 'ByteArray#', so
-- the finalizer may run once the array is no longer reachable. Use 'withCPtr'
-- to get at the raw pointer.
newtype CPtr a = CPtr (PrimArray (Ptr a))

-- | Two 'CPtr's are equal exactly when they wrap the same address.
instance Eq (CPtr a) where
    {-# INLINE (==) #-}
    (==) (CPtr lhs) (CPtr rhs) = addrOf lhs == addrOf rhs
      where
        -- the wrapped pointer lives in slot 0 of the one-element array
        addrOf arr = indexPrimArray arr 0

-- | 'CPtr's are ordered by the address they wrap.
instance Ord (CPtr a) where
    {-# INLINE compare #-}
    compare (CPtr lhs) (CPtr rhs) =
        compare (indexPrimArray lhs 0) (indexPrimArray rhs 0)

-- | Render a 'CPtr' through its 'T.Print' instance.
instance Show (CPtr a) where
    show cptr = T.toString cptr

-- | Print a 'CPtr' as the raw pointer it wraps; the precedence argument is
-- ignored and the pointer is printed at precedence 0.
instance T.Print (CPtr a) where
    {-# INLINE toUTF8BuilderP #-}
    toUTF8BuilderP _prec (CPtr arr) = T.toUTF8BuilderP 0 ptr
      where
        ptr = indexPrimArray arr 0

-- | Initialize a 'CPtr' with an initializer which returns an allocated pointer.
--
-- The finalizer is attached to a weak pointer keyed on the 'CPtr''s internal
-- array, and is called with the pointer after that array becomes unreachable.
-- If the initializer returns 'nullPtr' (e.g. a failed allocation) no finalizer
-- is registered, since a NULL pointer owns nothing to release.
--
-- Construction runs under 'mask_'.
newCPtr' :: IO (Ptr a)              -- ^ initializer
         -> FunPtr (Ptr a -> IO b)  -- ^ finalizer
         -> IO (CPtr a)
newCPtr' ini (FunPtr fin#) = mask_ $ do
    mpa <- newPrimArray 1
    p@(Ptr addr#) <- ini
    writePrimArray mpa 0 p
    pa@(PrimArray ba#) <- unsafeFreezePrimArray mpa
    -- Never hand NULL to the C finalizer: many destructors crash on it.
    unless (p == nullPtr) $
        primitive_ $ \ s0# ->
            -- Key a weak pointer on the array's ByteArray# and attach the C
            -- finalizer (called as fin(addr), flag 0# = no environment).
            let !(# s1#, w# #) = mkWeakNoFinalizer# ba# () s0#
                !(# s2#, _ #) = addCFinalizerToWeak# fin# addr# 0# addr# w# s1#
            in s2#
    return (CPtr pa)

-- | Initialize a 'CPtr' with an initializer (which must be an unsafe FFI call)
-- and a finalizer.
--
-- The initializer receives a one-element, unpinned 'MutableByteArray#' into
-- which it should write the allocated pointer. Because the array is unpinned,
-- it must only be passed to an /unsafe/ FFI call.
--
-- The slot is set to 'nullPtr' before the initializer runs, so an initializer
-- that fails without writing a pointer back yields a 'CPtr' holding 'nullPtr'
-- instead of uninitialized memory. No finalizer is registered for a NULL
-- pointer. Construction runs under 'mask_'.
newCPtrUnsafe :: (MutableByteArray# RealWorld -> IO r)  -- ^ initializer
              -> FunPtr (Ptr a -> IO b)                 -- ^ finalizer
              -> IO (CPtr a, r)
newCPtrUnsafe ini (FunPtr fin#) = mask_ $ do
    mpa@(MutablePrimArray mba#) <- newPrimArray 1
    -- 'newPrimArray' does not initialize its contents; start from NULL so a
    -- failing initializer never leaves a garbage address behind.
    writePrimArray mpa 0 nullPtr
    r <- ini mba#
    p@(Ptr addr#) <- readPrimArray mpa 0
    pa@(PrimArray ba#) <- unsafeFreezePrimArray mpa
    -- Never hand NULL to the C finalizer: many destructors crash on it.
    unless (p == nullPtr) $
        primitive_ $ \ s0# ->
            -- Key a weak pointer on the array's ByteArray# and attach the C
            -- finalizer (called as fin(addr), flag 0# = no environment).
            let !(# s1#, w# #) = mkWeakNoFinalizer# ba# () s0#
                !(# s2#, _ #) = addCFinalizerToWeak# fin# addr# 0# addr# w# s1#
            in s2#
    return (CPtr pa, r)

-- | Initialize a 'CPtr' with an initializer and a finalizer.
--
-- The initializer receives a pointer to a pinned, pointer-sized slot into
-- which it should write the allocated pointer (a @foo**@ out-parameter).
--
-- The slot is set to 'nullPtr' before the initializer runs, so an initializer
-- that fails without writing a pointer back yields a 'CPtr' holding 'nullPtr'
-- instead of uninitialized memory. No finalizer is registered for a NULL
-- pointer. Construction runs under 'mask_'.
newCPtr :: (Ptr (Ptr a) -> IO r)    -- ^ initializer
        -> FunPtr (Ptr a -> IO b)   -- ^ finalizer
        -> IO (CPtr a, r)
newCPtr ini (FunPtr fin#) = mask_ $ do
    mpa <- newPinnedPrimArray 1
    -- 'newPinnedPrimArray' does not initialize its contents; start from NULL
    -- so a failing initializer never leaves a garbage address behind.
    writePrimArray mpa 0 nullPtr
    r <- ini (mutablePrimArrayContents mpa)
    p@(Ptr addr#) <- readPrimArray mpa 0
    pa@(PrimArray ba#) <- unsafeFreezePrimArray mpa
    -- Never hand NULL to the C finalizer: many destructors crash on it.
    unless (p == nullPtr) $
        primitive_ $ \ s0# ->
            -- Key a weak pointer on the array's ByteArray# and attach the C
            -- finalizer (called as fin(addr), flag 0# = no environment).
            let !(# s1#, w# #) = mkWeakNoFinalizer# ba# () s0#
                !(# s2#, _ #) = addCFinalizerToWeak# fin# addr# 0# addr# w# s1#
            in s2#
    return (CPtr pa, r)

-- | The only way to use a 'CPtr' as a 'Ptr' in FFI is via 'withCPtr'.
--
-- The wrapped pointer is passed to the action; afterwards the 'CPtr''s
-- internal 'ByteArray#' (the key its finalizer is attached to) is touched so it
-- stays reachable until the action has returned. The touch only happens when
-- the action returns normally.
withCPtr :: CPtr a -> (Ptr a -> IO b) -> IO b
withCPtr (CPtr arr@(PrimArray ba#)) action = do
    let ptr = indexPrimArray arr 0
    result <- action ptr
    primitive_ (touch# ba#)
    pure result

-- | Pass a list of 'CPtr Foo' as @foo**@. USE THIS FUNCTION WITH UNSAFE FFI ONLY!
--
-- The raw pointers are copied, in list order, into a freshly allocated
-- unpinned array whose 'ByteArray#' is handed to the action together with the
-- number of pointers. Since the array is unpinned it may only be passed to an
-- unsafe FFI call. The list of 'CPtr's is touched after the action returns.
withCPtrsUnsafe :: forall a b. [CPtr a] -> (BA# (Ptr a) -> Int -> IO b) -> IO b
withCPtrsUnsafe cptrs f = do
    let n = length cptrs
    marr <- newPrimArray @IO @(Ptr a) n
    zipWithM_ (\ ix (CPtr pa) -> writePrimArray marr ix (indexPrimArray pa 0))
              [0 ..] cptrs
    PrimArray ba# <- unsafeFreezePrimArray marr
    result <- f ba# n
    primitive_ (touch# cptrs)
    pure result

-- | Pass a list of 'CPtr Foo' as @foo**@.
--
-- The raw pointers are copied, in list order, into a freshly allocated pinned
-- array; a 'Ptr' to its first element is handed to the action together with
-- the number of pointers. The list of 'CPtr's is touched after the action
-- returns.
withCPtrs :: forall a b. [CPtr a] -> (Ptr (Ptr a) -> Int -> IO b) -> IO b
withCPtrs cptrs f = do
    let n = length cptrs
    marr <- newPinnedPrimArray @IO @(Ptr a) n
    forM_ (zip [0 ..] cptrs) $ \ (ix, CPtr pa) ->
        writePrimArray marr ix (indexPrimArray pa 0)
    result <- withMutablePrimArrayContents marr (\ pp -> f pp n)
    primitive_ (touch# cptrs)
    pure result