{-# LANGUAGE TypeApplications #-}

module Extism.PDK.Bindings where

import Control.Monad
import Data.ByteString as B
import Data.ByteString.Internal
import Data.Int
import Data.Word
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import System.Exit

-- | Address of a block inside Extism memory, counted in bytes.
type MemoryOffset = Word64

-- | Size in bytes of a block of Extism memory.
type MemoryLength = Word64

-- | Byte position within the input, ranging from 0 up to (but excluding) the
-- 'InputLength'.
type InputOffset = Word64

-- | Total size of the input in bytes.
type InputLength = Word64

-- | Binding to the host function @extism_output_set@. Judging by its name it
-- sets the plugin output from the memory block at the given offset and
-- length (confirm against the Extism kernel documentation).
foreign import ccall "extism_output_set"
  extismSetOutput :: MemoryOffset -> MemoryLength -> IO ()

-- | Binding to @extism_error_set@. Takes a single 'MemoryOffset', presumably
-- the block holding the error message (TODO confirm).
foreign import ccall "extism_error_set"
  extismSetError :: MemoryOffset -> IO ()

-- Logging bindings. Each takes one 'MemoryOffset', presumably pointing at
-- the block that holds the message text (not verifiable from this module).

-- | Binding to @extism_log_info@ (info level, by name).
foreign import ccall "extism_log_info"
  extismLogInfo :: MemoryOffset -> IO ()

-- | Binding to @extism_log_warn@ (warning level, by name).
foreign import ccall "extism_log_warn"
  extismLogWarn :: MemoryOffset -> IO ()

-- | Binding to @extism_log_debug@ (debug level, by name).
foreign import ccall "extism_log_debug"
  extismLogDebug :: MemoryOffset -> IO ()

-- | Binding to @extism_log_error@ (error level, by name).
foreign import ccall "extism_log_error"
  extismLogError :: MemoryOffset -> IO ()

-- | Binding to @extism_store_u8@: write one byte at an offset in Extism
-- memory. 'writeBytesLoop' uses it for the trailing bytes of a write.
foreign import ccall "extism_store_u8"
  extismStoreU8 :: MemoryOffset -> Word8 -> IO ()

-- | Binding to @extism_store_u64@: write an 8-byte word at an offset in
-- Extism memory. 'writeBytesLoop' uses it for whole 8-byte chunks.
foreign import ccall "extism_store_u64"
  extismStoreU64 :: MemoryOffset -> Word64 -> IO ()

-- | Binding to @extism_load_u8@: read one byte at an offset in Extism memory.
foreign import ccall "extism_load_u8"
  extismLoadU8 :: MemoryOffset -> IO Word8

-- | Binding to @extism_load_u64@: read an 8-byte word at an offset in Extism
-- memory.
foreign import ccall "extism_load_u64"
  extismLoadU64 :: MemoryOffset -> IO Word64

-- | Binding to @extism_alloc@: takes a 'MemoryLength' and returns a
-- 'MemoryOffset'. By its name, it allocates a block of that length.
foreign import ccall "extism_alloc"
  extismAlloc :: MemoryLength -> IO MemoryOffset

-- | Binding to @extism_length@: returns the 'MemoryLength' associated with a
-- 'MemoryOffset', presumably the size of the block allocated there.
foreign import ccall "extism_length"
  extismLength :: MemoryOffset -> IO MemoryLength

-- | Binding to @extism_length_unsafe@. It has the same signature as
-- 'extismLength'. What the "unsafe" variant skips is not visible here, so
-- check the Extism kernel docs before relying on it.
foreign import ccall "extism_length_unsafe"
  extismLengthUnsafe :: MemoryOffset -> IO MemoryLength

-- | Binding to @extism_free@. By its name, it releases the block at the given
-- offset.
foreign import ccall "extism_free"
  extismFree :: MemoryOffset -> IO ()

-- | Binding to @extism_input_length@: the total 'InputLength' of the input.
foreign import ccall "extism_input_length"
  extismInputLength :: IO InputLength

-- | Binding to @extism_input_load_u8@: read the input byte at an
-- 'InputOffset'.
foreign import ccall "extism_input_load_u8"
  extismInputLoadU8 :: InputOffset -> IO Word8

-- | Binding to @extism_input_load_u64@: read an 8-byte word of input starting
-- at an 'InputOffset'.
foreign import ccall "extism_input_load_u64"
  extismInputLoadU64 :: InputOffset -> IO Word64

-- | Binding to @extism_config_get@: maps one 'MemoryOffset' to another.
-- Presumably the argument holds a config key and the result holds its value
-- (TODO confirm, including how a missing key is reported).
foreign import ccall "extism_config_get"
  extismGetConfig :: MemoryOffset -> IO MemoryOffset

-- | Binding to @extism_var_get@. It has the same shape as 'extismGetConfig';
-- presumably it reads a plugin variable by key (TODO confirm).
foreign import ccall "extism_var_get"
  extismGetVar :: MemoryOffset -> IO MemoryOffset

-- | Binding to @extism_var_set@: takes two offsets, presumably key then value.
foreign import ccall "extism_var_set"
  extismSetVar :: MemoryOffset -> MemoryOffset -> IO ()

-- | Binding to @extism_http_request@: takes two offsets and returns one.
-- Presumably these are the request description, the request body, and the
-- response body (TODO confirm).
foreign import ccall "extism_http_request"
  extismHTTPRequest :: MemoryOffset -> MemoryOffset -> IO MemoryOffset

-- | Binding to @extism_http_status_code@: returns an 'Int32', by name the
-- status code of the most recent HTTP request.
foreign import ccall "extism_http_status_code"
  extismHTTPStatusCode :: IO Int32

-- | Binding to @__wasm_call_ctors@. By the wasm-ld naming convention this
-- runs the module's static constructors.
foreign import ccall "__wasm_call_ctors"
  wasmConstructor :: IO ()

-- | Binding to @__wasm_call_dtors@, the destructor counterpart of
-- 'wasmConstructor' (by name; confirm which toolchain component provides it).
foreign import ccall "__wasm_call_dtors"
  wasmDestructor :: IO ()

-- | Reinterpret an 8-byte 'ByteString' as a 'Word64', reading its bytes in
-- the host's native byte order (via 'peek').
--
-- Calls 'error' with @"invalid bytestring"@ when the input is not exactly
-- 8 bytes long.
bsToWord64 :: ByteString -> IO Word64
bsToWord64 (BS fp len)
  | len == 8 = withForeignPtr fp $ \p -> peek (castPtr p :: Ptr Word64)
  | otherwise = error "invalid bytestring"

-- | Serialise a 'Word64' into a fresh 8-byte 'ByteString', laid out in the
-- host's native byte order (via 'poke'). This is the counterpart of
-- 'bsToWord64'.
word64ToBS :: Word64 -> ByteString
word64ToBS word = unsafeCreate (sizeOf word) writeWord
  where
    -- The buffer is exactly one Word64 wide, so a single poke fills it.
    writeWord buf = poke (castPtr buf) word

-- | Read the byte range @[index, total)@ through a pair of loaders and
-- collect it into one strict 'ByteString'.
--
-- While at least 8 bytes remain, the range is read a word at a time with
-- @f8@, each word turned back into bytes with 'word64ToBS'. The remaining
-- tail is read one byte at a time with @f1@. Chunks already in @acc@ (stored
-- most-recent-first) end up in front of the newly read bytes. The callers in
-- this module pass @[]@.
readLoop
  :: (Word64 -> IO Word8)  -- ^ @f1@: load the byte at an offset
  -> (Word64 -> IO Word64) -- ^ @f8@: load the 8-byte word at an offset
  -> Word64                -- ^ @total@: exclusive end offset
  -> Word64                -- ^ @index@: offset to start reading from
  -> [ByteString]          -- ^ @acc@: already-read chunks, newest first
  -> IO ByteString
readLoop f1 f8 total = go
  where
    go index acc
      | index >= total = pure (B.concat (Prelude.reverse acc))
      -- index < total here, so the subtraction cannot wrap.
      | total - index >= 8 = do
          chunk <- word64ToBS <$> f8 index
          go (index + 8) (chunk : acc)
      | otherwise = do
          chunk <- B.singleton <$> f1 index
          go (index + 1) (chunk : acc)

-- | Read the first @len@ bytes of the input (offsets @0@ to @len - 1@) into
-- a 'ByteString', using 'extismInputLoadU64' / 'extismInputLoadU8'.
readInputBytes :: InputLength -> IO ByteString
readInputBytes len = readLoop extismInputLoadU8 extismInputLoadU64 len start []
  where
    start :: InputOffset
    start = 0

-- | Copy @len@ bytes of Extism memory, starting at @offs@, into a
-- 'ByteString', using 'extismLoadU64' / 'extismLoadU8'.
readBytes :: MemoryOffset -> MemoryLength -> IO ByteString
readBytes offs len = readLoop extismLoadU8 extismLoadU64 end offs []
  where
    -- Exclusive end of the range being read.
    end = offs + len

-- | Store the first @total - index@ bytes of @src@ into Extism memory at
-- offsets @[index, total)@.
--
-- Whole 8-byte chunks are written with 'extismStoreU64' while at least 8
-- bytes remain. The tail is written one byte at a time with 'extismStoreU8'.
-- Bytes of @src@ beyond the range are ignored. If @index >= total@ nothing
-- is written.
--
-- Calls 'error' when @src@ holds fewer bytes than the range needs. The check
-- happens before anything is stored, so a short source never leaves a
-- partially written block behind. Previously this case failed midway through
-- the write, via partial 'B.head' or a confusing 'bsToWord64' error.
writeBytesLoop :: MemoryOffset -> MemoryOffset -> ByteString -> IO ()
writeBytesLoop index total src
  | index >= total = return ()
  | available < wanted =
      error "writeBytesLoop: source ByteString is shorter than the requested range"
  -- available < maxBound :: Int bounds wanted here, so the Int conversion is exact.
  | otherwise = go index (B.take (fromIntegral wanted) src)
  where
    -- Number of bytes the caller asked to store.
    wanted = total - index
    -- Bytes actually present in the source, widened for the comparison.
    available = fromIntegral (B.length src) :: Word64
    -- Invariant: B.length bytes == total - offs, so the loop ends exactly at
    -- total and uses no partial ByteString functions.
    go offs bytes = case B.uncons bytes of
      Nothing -> return ()
      Just (byte, rest)
        | B.length bytes >= 8 -> do
            -- The chunk is exactly 8 bytes, which bsToWord64 requires.
            let (chunk, afterChunk) = B.splitAt 8 bytes
            extismStoreU64 offs =<< bsToWord64 chunk
            go (offs + 8) afterChunk
        | otherwise -> do
            extismStoreU8 offs byte
            go (offs + 1) rest

-- | Store @len@ bytes of the given 'ByteString' into Extism memory starting
-- at @offs@. This delegates to 'writeBytesLoop' over the range
-- @[offs, offs + len)@.
writeBytes :: MemoryOffset -> MemoryLength -> ByteString -> IO ()
writeBytes offs len src = writeBytesLoop offs end src
  where
    -- Exclusive end of the destination range.
    end = offs + len