{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}

-- | A module providing a backend that launches solvers as external processes.
module SMTLIB.Backends.Process
  ( Config (..),
    Handle (..),
    new,
    wait,
    close,
    with,
    toBackend,
  )
where

import Control.Concurrent.Async (Async, async, cancel)
import qualified Control.Exception as X
import Control.Monad (forever)
import Data.ByteString.Builder
  ( Builder,
    byteString,
    hPutBuilder,
    toLazyByteString,
  )
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy.Char8 as LBS
import Data.Default (Default, def)
import SMTLIB.Backends (Backend (..))
import System.Exit (ExitCode)
import qualified System.IO as IO
import System.Process.Typed
  ( Process,
    getStderr,
    getStdin,
    getStdout,
    mkPipeStreamSpec,
    setStderr,
    setStdin,
    setStdout,
    startProcess,
    stopProcess,
    waitExitCode,
  )
import qualified System.Process.Typed as P (proc)

-- | How to launch the solver process and where to send its diagnostics.
data Config = Config
  { -- | The command to call to run the solver.
    exe :: String
  , -- | Arguments to pass to the solver's command.
    args :: [String]
  , -- | A function for logging the solver process' messages on stderr and file
    -- handle exceptions.
    -- If you want line breaks between each log message, you need to implement
    -- it yourself, e.g. use @'LBS.putStr' . (<> "\n")@.
    reportError :: LBS.ByteString -> IO ()
  }

-- | By default, use Z3 as an external process and ignore log messages.
instance Default Config where
  -- if you change this, make sure to also update the Haddock comment on this
  -- instance as well as the one in @smtlib-backends-process/tests/Examples.hs@
  def =
    Config
      { exe = "z3",
        args = ["-in"],
        reportError = \_ -> return ()
      }

-- | A running solver process together with the thread that logs its stderr.
data Handle = Handle
  { -- | The process running the solver (with all three standard streams
    -- connected to pipes when created by 'new').
    process :: Process IO.Handle IO.Handle IO.Handle
  , -- | An asynchronous thread reading the solver's error messages line by
    -- line and logging them with 'reportError'.
    errorReader :: Async ()
  }

-- | Run a solver as a process.
--
-- The solver's stdin, stdout and stderr are all connected to binary,
-- block-buffered pipes. A background thread forwards every line the solver
-- writes on stderr to 'reportError'; that thread stops silently as soon as
-- reading fails (e.g. once stderr is exhausted) or it is cancelled.
-- Failures relative to terminating the process are logged and discarded.
new ::
  -- | The solver process' configuration.
  Config ->
  IO Handle
new config = do
  solver <-
    startProcess
      . setStdin loggedPipe
      . setStdout loggedPipe
      . setStderr loggedPipe
      $ P.proc (exe config) (args config)
  -- forward the error messages created by the backend to the user's logger;
  -- any exception (end of stream, cancellation, ...) just ends the loop
  stderrLogger <-
    async . X.handle (\X.SomeException {} -> return ()) . forever $
      BS.hGetLine (getStderr solver) >>= logLine
  pure Handle {process = solver, errorReader = stderrLogger}
  where
    -- A pipe put in binary, block-buffered mode. Its cleanup action closes the
    -- handle and logs, instead of rethrowing, any 'X.IOException' this raises.
    -- No type signature on purpose: it is used at both input and output
    -- stream types and must stay polymorphic.
    loggedPipe = mkPipeStreamSpec $ \_ hdl -> do
      IO.hSetBinaryMode hdl True
      IO.hSetBuffering hdl (IO.BlockBuffering Nothing)
      let closeAndLog =
            X.handle
              (\ex -> logLine (BS.pack (show (ex :: X.IOException))))
              (IO.hClose hdl)
      pure (hdl, closeAndLog)
    -- report a strict line through the user's (lazy-ByteString) logger
    logLine = reportError config . LBS.fromStrict

-- | Wait for the process to exit and cleanup its resources.
--
-- The stderr-logging thread is only cancelled once the process has exited, so
-- error messages written while the solver shuts down are still passed to
-- 'reportError'. It also means a solver that is blocked writing to a full
-- stderr pipe is not left without a reader, which would otherwise make this
-- wait forever.
--
-- Afterwards 'stopProcess' is called so that the streams' cleanup actions run
-- and the pipes get closed; the process has already exited at that point.
-- Returns the exit code of the solver process.
wait :: Handle -> IO ExitCode
wait handle = do
  exitCode <- waitExitCode $ process handle
  cancel $ errorReader handle
  stopProcess $ process handle
  return exitCode

-- | Terminate the process, wait for it to actually exit and cleanup its resources.
-- The thread logging the solver's stderr is cancelled before the process is
-- stopped.
-- Don't use this if you're manually stopping the solver process by sending an
-- @(exit)@ command. Use `wait` instead.
close :: Handle -> IO ()
close Handle {process = solver, errorReader = stderrLogger} = do
  cancel stderrLogger
  stopProcess solver

-- | Create a solver process, use it to make a computation and stop it.
-- The process is stopped with 'close' even if the computation throws an
-- exception (the whole thing is wrapped in 'X.bracket').
-- Don't use this if you're manually stopping the solver process by sending an
-- @(exit)@ command. Use @\\config -> `Control.Exception.bracket` (`new` config) `wait`@ instead.
with ::
  -- | The solver process' configuration.
  Config ->
  -- | The computation to run with the solver process
  (Handle -> IO a) ->
  IO a
with config = X.bracket (new config) close

infixr 5 :<

-- | Match a non-empty strict 'BS.ByteString' as its first character and the
-- rest of the string, e.g. @'(' :< rest@. This is a match-only pattern: it
-- cannot be used to construct a 'BS.ByteString'.
pattern (:<) :: Char -> BS.ByteString -> BS.ByteString
pattern hd :< tl <- (BS.uncons -> Just (hd, tl))

-- | Make the solver process into an SMT-LIB backend.
--
-- Sending a command writes it to the solver's stdin followed by a newline,
-- flushes stdin, and then reads lines from the solver's stdout until a
-- complete response has been read (see @scanParen@ below).
toBackend :: Handle -> Backend
toBackend handle =
  Backend $ \cmd -> do
    hPutBuilder (getStdin $ process handle) $ cmd <> "\n"
    IO.hFlush $ getStdin $ process handle
    toLazyByteString <$> continueNextLine (scanParen 0) mempty
  where
    -- scanParen reads lines from the handle's output channel until it has
    -- detected a complete s-expression, i.e. a well-parenthesized word that may
    -- contain strings, quoted symbols, and comments.
    -- If we detect a ')' at depth 0 that is not enclosed in a string, a quoted
    -- symbol or a comment, we give up and return immediately.
    -- See also the SMT-LIB standard v2.6:
    -- https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf#part.2
    scanParen :: Int -> Builder -> BS.ByteString -> IO Builder
    scanParen depth acc ('(' :< more) = scanParen (depth + 1) acc more
    scanParen depth acc ('"' :< more) = do
      (acc', more') <- string acc more
      scanParen depth acc' more'
    scanParen depth acc ('|' :< more) = do
      (acc', more') <- quotedSymbol acc more
      scanParen depth acc' more'
    -- a comment extends to the end of the line
    scanParen depth acc (';' :< _) = continueNextLine (scanParen depth) acc
    scanParen depth acc (')' :< more)
      -- closing the outermost parenthesis (or a stray ')' at depth 0) ends
      -- the response
      | depth <= 1 = return acc
      | otherwise = scanParen (depth - 1) acc more
    scanParen depth acc (_ :< more) = scanParen depth acc more
    -- end of line with nothing left open: the response is complete
    scanParen 0 acc _ = return acc
    -- end of line inside an unfinished s-expression: keep reading
    scanParen depth acc _ = continueNextLine (scanParen depth) acc

    -- string skips the body of a string literal whose opening '"' has already
    -- been consumed, and returns the rest of the line after the closing '"'.
    -- A doubled '""' is an escaped quote and does not end the literal.
    string :: Builder -> BS.ByteString -> IO (Builder, BS.ByteString)
    string acc ('"' :< '"' :< more) = string acc more
    string acc ('"' :< more) = return (acc, more)
    string acc (_ :< more) = string acc more
    -- mempty case: the literal continues on the next line
    string acc _ = continueNextLine string acc

    -- quotedSymbol skips the body of a quoted symbol whose opening '|' has
    -- already been consumed, and returns the rest of the line after the
    -- closing '|'. Every other character, including '"', is part of the
    -- symbol, so we must keep scanning as a quoted symbol (not as a string).
    quotedSymbol :: Builder -> BS.ByteString -> IO (Builder, BS.ByteString)
    quotedSymbol acc ('|' :< more) = return (acc, more)
    quotedSymbol acc (_ :< more) = quotedSymbol acc more
    -- mempty case: the symbol continues on the next line
    quotedSymbol acc _ = continueNextLine quotedSymbol acc

    -- continueNextLine reads one more line from the solver's stdout, appends
    -- it to the accumulated response and resumes scanning on that line.
    -- NOTE(review): 'BS.hGetLine' strips the newline and lines are appended
    -- without a separator, so line breaks of multi-line responses are not part
    -- of the returned text.
    continueNextLine :: (Builder -> BS.ByteString -> IO a) -> Builder -> IO a
    continueNextLine f acc = do
      next <- BS.hGetLine $ getStdout $ process handle
      f (acc <> byteString next) next