{-# LANGUAGE CPP #-}
{-|
Module      : Lua.Call
Copyright   : © 2007–2012 Gracjan Polak;
              © 2012–2016 Ömer Sinan Ağacan;
              © 2017-2024 Albert Krewinkel
License     : MIT
Maintainer  : Albert Krewinkel <tarleb@hslua.org>
Stability   : beta
Portability : non-portable (depends on GHC)

Function to push Haskell functions as Lua C functions.

Haskell functions are converted into C functions in a two-step process.
First, a function pointer to the Haskell function is stored in a Lua
userdata object. The userdata gets a metatable which allows the object
to be invoked as a function. The userdata also ensures that the function
pointer is freed when the object is garbage collected in Lua.

In a second step, the userdata is then wrapped into a C closure. The
wrapping function calls the userdata object and implements the error
protocol, converting special error values into proper Lua errors.
-}
module Lua.Call
  ( hslua_pushhsfunction
  ) where

import Foreign.C (CInt (CInt))
import Foreign.Ptr (Ptr, castPtr, nullPtr)
import Foreign.StablePtr (StablePtr, deRefStablePtr, newStablePtr)
import Foreign.Storable (peek)
import Lua.Types
  ( NumResults (NumResults)
  , PreCFunction
  , State (State)
  )

-- Safety level used by the FFI imports in this module. When the package
-- is built with ALLOW_UNSAFE_GC defined, the imports are declared
-- @unsafe@ (lower call overhead); otherwise they are @safe@. GHC does
-- not allow an @unsafe@ foreign call to call back into Haskell, so the
-- unsafe variant is only sound if these C functions never do so.
-- NOTE(review): presumably the concern is Lua running finalizers during
-- these calls, hence the flag name; confirm against the package flags.
-- NOTE(review): SAFTY looks like a misspelling of SAFETY; renaming it
-- requires updating every use in this file.
#ifdef ALLOW_UNSAFE_GC
#define SAFTY unsafe
#else
#define SAFTY safe
#endif

-- | Retrieve the pointer to a Haskell function from the wrapping
-- userdata object.
--
-- The result points at the memory holding the function's 'StablePtr';
-- 'hslua_callhsfun' reads it with 'peek' and then 'deRefStablePtr'.
-- That caller treats a 'nullPtr' result as a corrupted Lua object.
--
-- NOTE(review): which Lua stack slot is inspected is decided in
-- hslcall.c and is not visible here.
foreign import ccall SAFTY "hslcall.c hslua_extracthsfun"
  hslua_extracthsfun :: State -> IO (Ptr ())

-- | Creates a new C function from a 'PreCFunction'. The stable
-- pointer to the PreCFunction is stored in a userdata object, which
-- is then wrapped by a C closure. The userdata object ensures that
-- the stable pointer is freed when the function is garbage collected
-- in Lua.
--
-- The type variable @a@ is unconstrained here, but 'hslua_callhsfun'
-- dereferences the stored pointer as a 'PreCFunction'. Callers must
-- therefore pass a @StablePtr PreCFunction@, as 'hslua_pushhsfunction'
-- does. Ownership of the stable pointer passes to the userdata, so
-- callers must not free it themselves.
--
-- NOTE(review): presumably the new function is left on top of the Lua
-- stack (the name of 'hslua_pushhsfunction' suggests so); confirm in
-- hslcall.c.
foreign import ccall SAFTY "hslcall.c hslua_newhsfunction"
  hslua_newhsfunction :: State -> StablePtr a -> IO ()

-- | Pushes a Haskell operation as a Lua function. The Haskell operation
-- is expected to follow the custom error protocol, i.e., it must signal
-- errors with @'Lua.hslua_error'@.
--
-- A 'StablePtr' to the operation is created and handed to the C side.
-- The userdata that wraps it takes ownership and frees the stable
-- pointer when Lua garbage-collects the function, so the caller does
-- not need to (and must not) free anything.
--
-- === Example
-- Export the function to calculate triangular numbers.
--
-- > let triangular :: PreCFunction
-- >     triangular l' = do
-- >       n <- lua_tointegerx l' (nthBottom 1) nullPtr
-- >       lua_pushinteger l' (sum [1..n])
-- >       return (NumResults 1)
-- >
-- > hslua_pushhsfunction l triangular
-- > withCString "triangular" (lua_setglobal l)
--
hslua_pushhsfunction :: State -> PreCFunction -> IO ()
hslua_pushhsfunction l preCFn = do
  -- Keeps preCFn reachable from C; released by the userdata's cleanup.
  fnPtr <- newStablePtr preCFn
  hslua_newhsfunction l fnPtr
{-# INLINABLE hslua_pushhsfunction #-}

-- | Invoke the Haskell function whose stable pointer is stored in the
-- wrapping userdata. The C code installs this function as the
-- @__call@ value of that userdata's metatable, which is why it is
-- exported through the FFI below.
--
-- The pointer obtained from 'hslua_extracthsfun' is read as a
-- 'StablePtr' to a 'PreCFunction', dereferenced, and the resulting
-- function is run on the same Lua state. A 'nullPtr' means the object
-- does not carry a Haskell function, which is reported via 'error'.
hslua_callhsfun :: PreCFunction
hslua_callhsfun l = hslua_extracthsfun l >>= runStored
  where
    runStored udPtr
      | udPtr == nullPtr =
          error "Cannot call function; corrupted Lua object!"
      | otherwise = do
          stored <- peek (castPtr udPtr)
          hsFun <- deRefStablePtr stored
          hsFun l

foreign export ccall hslua_callhsfun :: PreCFunction