{-# Language DataKinds #-}
{-# Language GADTs #-}
{-# Language PolyKinds #-}
{-# Language ScopedTypeVariables #-}
{-# Language TypeApplications #-}
{-# Language QuasiQuotes #-}

{- |
    Module: EVM.Solvers
    Description: Solver orchestration
-}
module EVM.Solvers where

import Prelude hiding (LT, GT)

import GHC.Natural
import Control.Concurrent (forkIO, killThread)
import Control.Concurrent.Chan (Chan, newChan, writeChan, readChan)
import Control.Exception (IOException, SomeException, bracket, finally, try)
import Control.Monad
import Data.Char (isSpace)
import GHC.IO.Handle (Handle, hFlush, hSetBuffering, BufferMode(..))

import Data.Maybe (fromMaybe)
import Data.Text.Lazy (Text)
import qualified Data.Text as TS
import qualified Data.Text.Lazy as T
import qualified Data.Text.Lazy.IO as T
import Data.Text.Lazy.Builder
import System.Process (createProcess, cleanupProcess, proc, ProcessHandle, std_in, std_out, std_err, StdStream(..))

import EVM.SMT

-- | Supported solvers
data Solver
  = Z3
    -- ^ started as the @z3@ executable
  | CVC5
    -- ^ started as the @cvc5@ executable
  | Bitwuzla
    -- ^ started as @bitwuzla@; 'solverArgs' does not support it yet
  | Custom Text
    -- ^ any other executable, named by the wrapped text

-- | Renders the executable name that 'spawnSolver' passes to 'proc'.
instance Show Solver where
  show solver = case solver of
    Z3 -> "z3"
    CVC5 -> "cvc5"
    Bitwuzla -> "bitwuzla"
    Custom exe -> T.unpack exe


-- | A running solver instance
data SolverInstance = SolverInstance
  { _type    :: Solver         -- ^ which solver this process is running
  , _stdin   :: Handle         -- ^ pipe that SMT commands are written to
  , _stdout  :: Handle         -- ^ pipe that solver replies are read from
  , _stderr  :: Handle         -- ^ stderr pipe (not read by any function in this module)
  , _process :: ProcessHandle  -- ^ process handle, used by 'stopSolver'
  }

-- | A channel representing a group of solvers
--
-- Each 'Task' written to the channel is picked up by the orchestration thread
-- started in 'withSolvers' and handed to whichever solver instance is idle.
newtype SolverGroup = SolverGroup (Chan Task)

-- | A script to be executed (which also carries the variables whose models are extracted on a sat result), and a channel where the result should be written
data Task = Task
  { script :: SMT2
    -- ^ the query; also carries the variables whose values are read back on @sat@
  , resultChan :: Chan CheckSatResult
    -- ^ the channel the worker writes the 'CheckSatResult' to
  }

-- | The result of a call to (check-sat)
data CheckSatResult
  = Sat SMTCex
    -- ^ satisfiable, together with the model extracted from the solver
  | Unsat
    -- ^ unsatisfiable
  | Unknown
    -- ^ the solver answered @unknown@ or @timeout@
  | Error TS.Text
    -- ^ talking to the solver failed, or its answer could not be interpreted
  deriving (Eq, Show)

-- | Does the result carry a satisfying model?
isSat :: CheckSatResult -> Bool
isSat result = case result of
  Sat _ -> True
  _ -> False

-- | Did the query end in an 'Error'?
isErr :: CheckSatResult -> Bool
isErr result
  | Error _ <- result = True
  | otherwise = False

-- | Was the query proven unsatisfiable?
--
-- Uses the derived 'Eq' instance; comparing different constructors never
-- inspects the fields, so this does not force a 'Sat' model.
isUnsat :: CheckSatResult -> Bool
isUnsat = (== Unsat)

-- | Submit a query to the solver group and block until its verdict arrives.
--
-- A fresh reply channel is created for every call, so concurrent callers never
-- receive each other's results.
checkSat :: SolverGroup -> SMT2 -> IO CheckSatResult
checkSat (SolverGroup queue) smt = do
  reply <- newChan
  writeChan queue Task { script = smt, resultChan = reply }
  readChan reply

-- | Spawn @count@ instances of @solver@ and run @cont@ with a 'SolverGroup'
-- that dispatches each 'checkSat' query to whichever instance is idle.
--
-- The solver processes and the orchestration thread are acquired with
-- 'bracket', so they are released even if @cont@ throws. Exceptions thrown
-- while a worker talks to its solver are reported to the waiting caller as an
-- 'Error' result, so 'checkSat' is not left waiting for a reply that never
-- comes. The instance is always handed back to the pool.
--
-- NOTE(review): with @count == 0@ no instance is ever available, so a
-- 'checkSat' issued by @cont@ never gets an answer. If spawning one of the
-- instances fails, the instances spawned before it are not stopped.
withSolvers :: Solver -> Natural -> Maybe Natural -> (SolverGroup -> IO a) -> IO a
withSolvers solver count timeout cont =
  bracket spawnAll (mapM_ stopSolver) $ \instances -> do
    taskQueue <- newChan
    availableInstances <- newChan
    forM_ instances (writeChan availableInstances)
    bracket (forkIO $ orchestrate taskQueue availableInstances) killThread $ \_ ->
      cont (SolverGroup taskQueue)
  where
    spawnAll :: IO [SolverInstance]
    spawnAll = forM [1 .. count] (const $ spawnSolver solver timeout)

    -- Pair each incoming task with the next idle instance and run it on its
    -- own thread; the instance re-enters the pool when the task finishes.
    orchestrate :: Chan Task -> Chan SolverInstance -> IO ()
    orchestrate queue avail = forever $ do
      task <- readChan queue
      inst <- readChan avail
      void . forkIO $ runTask task inst avail

    -- Answer one task on the given instance. A reply is always written to the
    -- task's result channel, and the instance is always handed back (via
    -- 'finally'), even if talking to the solver throws.
    runTask :: Task -> SolverInstance -> Chan SolverInstance -> IO ()
    runTask (Task smt2 r) inst availableInstances =
      flip finally (writeChan availableInstances inst) $ do
        outcome <- try @SomeException (runQuery inst smt2)
        writeChan r (either reportException id outcome)

    reportException :: SomeException -> CheckSatResult
    reportException e =
      Error ("exception while communicating with solver: " <> TS.pack (show e))

    -- Reset the solver, send the script, issue (check-sat) and interpret the
    -- answer, extracting the model values when the answer is sat.
    runQuery :: SolverInstance -> SMT2 -> IO CheckSatResult
    runQuery inst (SMT2 cmds cexvars) = do
      out <- sendScript inst (SMT2 ("(reset)" : cmds) cexvars)
      case out of
        Left e ->
          pure $ Error ("error while writing SMT to solver: " <> T.toStrict e)
        Right () -> do
          sat <- sendLine inst "(check-sat)"
          case sat of
            "sat" -> do
              calldatamodels <- getVars parseVar (getValue inst) (fmap T.toStrict cexvars.calldataV)
              buffermodels <- getBufs (getValue inst) (fmap T.toStrict cexvars.buffersV)
              storagemodels <- getStore (getValue inst) cexvars.storeReads
              blockctxmodels <- getVars parseBlockCtx (getValue inst) (fmap T.toStrict cexvars.blockContextV)
              txctxmodels <- getVars parseFrameCtx (getValue inst) (fmap T.toStrict cexvars.txContextV)
              pure $ Sat $ SMTCex
                { vars = calldatamodels
                , buffers = buffermodels
                , store = storagemodels
                , blockContext = blockctxmodels
                , txContext = txctxmodels
                }
            "unsat" -> pure Unsat
            "timeout" -> pure Unknown
            "unknown" -> pure Unknown
            _ -> pure . Error $ T.toStrict $ "Unable to parse solver output: " <> sat

-- | Arguments used when spawning a solver instance
--
-- Only CVC5 takes its time limit on the command line: @timeout@ (10 when
-- 'Nothing') is multiplied by 1000 and passed as @--tlimit-per@.
solverArgs :: Solver -> Maybe Natural -> [Text]
solverArgs Z3 _ = [ "-in" ]
solverArgs CVC5 timeout =
  [ "--lang=smt"
  , "--no-interactive"
  , "--produce-models"
  , "--tlimit-per=" <> T.pack (show perQueryLimit)
  ]
  where
    perQueryLimit = 1000 * fromMaybe 10 timeout
solverArgs Bitwuzla _ = error "TODO: Bitwuzla args"
solverArgs (Custom _) _ = []

-- | Spawns a solver instance, and sets the various global config options that we use for our queries
--
-- The timeout is treated as a number of seconds, matching 'solverArgs' (which
-- passes @1000 * timeout@ to CVC5's millisecond-based @--tlimit-per@). Every
-- other solver receives it via @(set-option :timeout ...)@. Z3 interprets that
-- value in milliseconds, so it is scaled by 1000 here as well.
--
-- NOTE(review): stderr is piped but never drained in this module; a solver
-- that writes a lot to stderr could block once the pipe buffer fills.
spawnSolver :: Solver -> Maybe (Natural) -> IO SolverInstance
spawnSolver solver timeout = do
  let cmd = (proc (show solver) (fmap T.unpack $ solverArgs solver timeout))
        { std_in = CreatePipe, std_out = CreatePipe, std_err = CreatePipe }
  (Just stdin, Just stdout, Just stderr, process) <- createProcess cmd
  hSetBuffering stdin (BlockBuffering (Just 1000000))
  let solverInstance = SolverInstance solver stdin stdout stderr process
  case timeout of
    Nothing -> pure solverInstance
    Just t -> case solver of
      -- CVC5 already received its limit on the command line (see 'solverArgs')
      CVC5 -> pure solverInstance
      _ -> do
        -- seconds -> milliseconds, the unit of the :timeout option
        sendLine' solverInstance $
          "(set-option :timeout " <> T.pack (show (1000 * t)) <> ")"
        pure solverInstance

-- | Shut down a running solver instance, closing its pipes and terminating its process
stopSolver :: SolverInstance -> IO ()
stopSolver inst = case inst of
  SolverInstance _ input output errors ph ->
    cleanupProcess (Just input, Just output, Just errors, ph)

-- | Sends a list of commands to the solver. Returns the first error, if there was one.
--
-- The whole script is written and flushed in one go. An 'IOException' raised
-- by that write (e.g. a broken pipe because the solver process has exited) is
-- returned as 'Left' instead of being thrown. Errors that the solver reports
-- on its stdout are not inspected here.
sendScript :: SolverInstance -> SMT2 -> IO (Either Text ())
sendScript solver (SMT2 cmds _) = do
  written <- try @IOException $
    sendLine' solver (T.unlines $ fmap toLazyText cmds)
  pure $ either (Left . T.pack . show) Right written

-- | Sends a single command to the solver, returns the first available line from the output buffer
--
-- Leading whitespace is stripped first. Blank lines and @;@ comment lines are
-- not sent at all; they are answered locally with @"success"@.
sendCommand :: SolverInstance -> Text -> IO Text
sendCommand inst cmd =
  case T.uncons trimmed of
    Nothing       -> pure "success"
    Just (';', _) -> pure "success"
    Just _        -> sendLine inst trimmed
  where
    trimmed = T.dropWhile isSpace cmd

-- | Sends a string to the solver and appends a newline, returns the first available line from the output buffer
--
-- Only one line is read back, so this suits commands whose reply fits on a
-- single line (such as @(check-sat)@).
sendLine :: SolverInstance -> Text -> IO Text
sendLine inst@(SolverInstance _ _ output _ _) cmd = do
  sendLine' inst cmd
  T.hGetLine output

-- | Sends a string to the solver and appends a newline, doesn't return stdout
--
-- The write is flushed immediately; nothing is read back from the solver.
sendLine' :: SolverInstance -> Text -> IO ()
sendLine' (SolverInstance _ input _ _ _) cmd =
  T.hPutStr input (cmd <> "\n") >> hFlush input

-- | Returns a string representation of the model for the requested variable
--
-- Issues @(get-value (var))@ and reads back one balanced s-expression, which
-- may span several lines.
getValue :: SolverInstance -> Text -> IO Text
getValue inst@(SolverInstance _ _ output _ _) var = do
  sendLine' inst (T.concat ["(get-value (", var, "))"])
  -- readSExpr yields the reply lines newest-first; restore their order
  T.unlines . reverse <$> readSExpr output

-- | Reads lines from h until we have a balanced sexpr
--
-- The lines are returned newest-first. Parentheses are counted naively, so a
-- paren inside a string literal or quoted symbol in the solver output would
-- be miscounted.
readSExpr :: Handle -> IO [Text]
readSExpr h = collect 0 []
  where
    -- Read one more line and keep a running balance of '(' minus ')'; stop as
    -- soon as the balance is back to zero.
    collect balance acc = do
      line <- T.hGetLine h
      let balance' = balance + occurrences '(' line - occurrences ')' line
          acc' = line : acc
      if balance' == 0
        then pure acc'
        else collect balance' acc'

    occurrences c = T.length . T.filter (== c)