{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{- |
   Module      : Text.Pandoc.Lua.Util
   Copyright   : © 2012-2021 John MacFarlane,
                 © 2017-2021 Albert Krewinkel
   License     : GNU GPL, version 2 or above

   Maintainer  : Albert Krewinkel <tarleb+pandoc@moltkeplatz.de>
   Stability   : alpha

Lua utility functions.
-}
module Text.Pandoc.Lua.Util
  ( getTag
  , addField
  , addFunction
  , pushViaConstructor
  , callWithTraceback
  , dofileWithTraceback
  , pushViaConstr'
  ) where

import Control.Monad (unless, when)
import HsLua
import qualified HsLua as Lua

-- | Add a value to the table at the top of the stack at a string-index.
--
-- The key and then the value are pushed, so the target table sits at
-- position 3 (counted from the top) when 'Lua.rawset' runs. The assignment
-- is raw, i.e., it bypasses metamethods. Key and value are consumed by the
-- assignment; the table stays on top of the stack.
addField :: (LuaError e, Pushable a) => String -> a -> LuaE e ()
addField key value = do
  Lua.push key
  Lua.push value
  Lua.rawset (Lua.nth 3)

-- | Add a function to the table at the top of the stack, using the
-- given name.
--
-- The Haskell value is exposed to Lua via 'toHaskellFunction' and stored
-- with a raw assignment; the table remains on top of the stack afterwards.
addFunction :: Exposable e a => String -> a -> LuaE e ()
addFunction name fn = do
  Lua.push name
  Lua.pushHaskellFunction $ toHaskellFunction fn
  -- Stack is now: table, name, function; the table is 3rd from the top.
  Lua.rawset (Lua.nth 3)

-- | Helper class for pushing a single value to the stack via a Lua
-- function. See @pushViaCall@.
--
-- The instances collect the call arguments one Haskell argument at a time:
-- each argument extends the accumulated push action and increments the
-- argument count, until the result type @LuaE e ()@ is reached and the call
-- is actually performed.
class LuaError e => PushViaCall e a where
  pushViaCall' :: LuaError e => Name -> LuaE e () -> NumArgs -> a

-- | Base case: fetch the function stored in the registry under the given
-- name, push all accumulated arguments, and call it with one result.
instance LuaError e => PushViaCall e (LuaE e ()) where
  pushViaCall' fn pushArgs num = do
    Lua.pushName @e fn
    Lua.rawget Lua.registryindex
    pushArgs
    Lua.call num 1

-- | Recursive case: take one more argument, schedule it to be pushed after
-- the previously collected ones, and bump the argument count.
instance (LuaError e, Pushable a, PushViaCall e b) =>
         PushViaCall e (a -> b) where
  pushViaCall' fn pushArgs num x =
    pushViaCall' @e fn (pushArgs *> Lua.push x) (num + 1)

-- | Push a value to the stack via a Lua function. The function is fetched
-- from the registry under the given name, called with all arguments that are
-- passed to this function, and is expected to return a single value.
--
-- The error type @e@ does not occur in the argument types, so callers supply
-- it with a type application (as 'pushViaConstructor' does).
pushViaCall :: forall e a. LuaError e => PushViaCall e a => Name -> a
pushViaCall fn = pushViaCall' @e fn (return ()) 0

-- | Call a pandoc element constructor within Lua, passing all given arguments.
--
-- The constructor is the function registered under the name
-- @"pandoc." <> pandocFn@; its single result is left on top of the stack.
pushViaConstructor :: forall e a. LuaError e => PushViaCall e a => Name -> a
pushViaConstructor pandocFn = pushViaCall @e ("pandoc." <> pandocFn)

-- | Get the tag of a value. This is an optimized and specialized version of
-- @Lua.getfield idx "tag"@.
--
-- The field is read from the metatable of the value at index @idx@ if the
-- value has a metatable, and from the value itself otherwise (never from
-- both). The lookup is raw, so any @__index@ metamethod is ignored.
--
-- The two values pushed during the lookup are popped again before
-- returning, even if the field cannot be retrieved as a 'Name'.
getTag :: LuaError e => Peeker e Name
getTag idx = do
  liftLua $ do
    -- Push the metatable, or the value itself if it has no metatable.
    hasMT <- Lua.getmetatable idx
    unless hasMT $
      Lua.pushvalue idx
    Lua.pushName "tag"
    Lua.rawget (Lua.nth 2)
  -- Pop the `tag` field and the table/metatable below it.
  Lua.peekName Lua.top `lastly` Lua.pop 2

-- | Push a pandoc element by calling the Lua function registered under
-- @"pandoc." <> fnname@ with the given argument-pushing actions.
--
-- This is a list-based alternative to 'pushViaConstructor'. The number of
-- call arguments is taken to be the length of the list, so each action must
-- push exactly one value. The single result of the call is left on top of
-- the stack.
pushViaConstr' :: forall e. LuaError e => Name -> [LuaE e ()] -> LuaE e ()
pushViaConstr' fnname pushArgs = do
  Lua.pushName @e ("pandoc." <> fnname)
  Lua.rawget @e Lua.registryindex
  sequence_ pushArgs
  Lua.call @e (fromIntegral (length pushArgs)) 1

-- | Like @'Lua.pcall'@, but uses a predefined error handler which adds a
-- traceback on error.
--
-- Expects the function and its @nargs@ arguments on top of the stack, just
-- like 'Lua.pcall'. The handler is inserted below the function for the
-- duration of the call and removed afterwards, so the caller sees the same
-- stack layout as with a plain 'Lua.pcall'.
pcallWithTraceback :: LuaError e => NumArgs -> NumResults -> LuaE e Status
pcallWithTraceback nargs nresults = do
  -- Message handler: turn the error object into a string and return it
  -- together with a stack traceback.
  let traceback' :: LuaError e => LuaE e NumResults
      traceback' = do
        l <- Lua.state
        msg <- Lua.tostring' (Lua.nthBottom 1)
        Lua.traceback l (Just msg) 2
        return 1
  -- Absolute position of the function to be called (below its arguments);
  -- the handler is inserted there, shifting function and arguments up.
  tracebackIdx <- Lua.absindex (Lua.nth (Lua.fromNumArgs nargs + 1))
  Lua.pushHaskellFunction traceback'
  Lua.insert tracebackIdx
  result <- Lua.pcall nargs nresults (Just tracebackIdx)
  Lua.remove tracebackIdx
  return result

-- | Like @'Lua.call'@, but adds a traceback to the error message (if any).
--
-- If the call fails, the error message (including the traceback) is
-- rethrown as a Haskell exception via 'Lua.throwErrorAsException'.
callWithTraceback :: LuaError e => NumArgs -> NumResults -> LuaE e ()
callWithTraceback nargs nresults = do
  result <- pcallWithTraceback nargs nresults
  when (result /= Lua.OK)
    Lua.throwErrorAsException

-- | Load and run the Lua file at the given path, adding a traceback to the
-- error message if an error occurs while running it.
--
-- If loading fails, the chunk is not run and the load status is returned
-- (without a traceback). Otherwise the chunk is called with no arguments,
-- its results are all kept on the stack, and the status of that call is
-- returned.
dofileWithTraceback :: LuaError e => FilePath -> LuaE e Status
dofileWithTraceback fp = do
  loadRes <- Lua.loadfile fp
  case loadRes of
    Lua.OK -> pcallWithTraceback 0 Lua.multret
    _      -> return loadRes