-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SBV.SMT.SMT
-- Copyright   :  (c) Levent Erkok
-- License     :  BSD3
-- Maintainer  :  erkokl@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Abstraction of SMT solvers
-----------------------------------------------------------------------------

{-# LANGUAGE ScopedTypeVariables #-}

module Data.SBV.SMT.SMT where

import Control.Monad(when, zipWithM)
import Control.DeepSeq(NFData(..))
import Data.Char(isSpace)
import Data.List(intercalate)
import Data.Word
import Data.Int
import Data.SBV.BitVectors.Bit
import Data.SBV.BitVectors.Data
import Data.SBV.BitVectors.PrettyNum
import Data.SBV.Utils.TDiff
import System.Directory(findExecutable)
import System.Process(readProcessWithExitCode)
import System.Exit

-- | Settings that govern a single interaction with an SMT solver.
data SMTConfig = SMTConfig
  { -- | Debug mode
    verbose :: Bool
    -- | Print timing information on how long different phases took
    -- (construction, solving, etc.)
  , timing :: Bool
    -- | Print literals in this base
  , printBase :: Int
    -- | The actual SMT solver
  , solver :: SMTSolver
  }

-- | The shape of a solver engine: an 'IO' action producing an 'SMTResult'
-- from a configuration, a list of named symbolic variables, and a 'String'.
type SMTEngine
  =  SMTConfig      -- configuration in effect for the call
  -> [NamedSymVar]  -- named symbolic variables
  -> String         -- NOTE(review): presumably the script sent to the solver; confirm against the engines
  -> IO SMTResult

-- | Description of an SMT solver back-end.
data SMTSolver = SMTSolver
  { -- | Printable name of the solver
    name :: String
    -- | The path to its executable
  , executable :: String
    -- | Options to provide to the solver
  , options :: [String]
    -- | The solver engine, responsible for interpreting solver output
  , engine :: SMTEngine
  }

-- | Outcome of calling an SMT solver. Every constructor carries the
-- 'SMTConfig' it was produced under, so that further tools can inspect it
-- and build layers of results if needed. Ordinary users of the library
-- should not need this type directly; use the accessor functions on it
-- instead (custom 'Show' instances and model extractors).
data SMTResult
  = Unsatisfiable SMTConfig
    -- ^ Unsatisfiable
  | Satisfiable SMTConfig [(String, CW)]
    -- ^ Satisfiable with model
  | Unknown SMTConfig [(String, CW)]
    -- ^ Prover returned unknown, with a potential (possibly bogus) model
  | ProofError SMTConfig [String]
    -- ^ Prover errored out
  | TimeOut SMTConfig
    -- ^ Computation timed out (see the 'timeout' combinator)

-- | Extract the 'SMTConfig' stored in a result, whichever constructor it is.
resultConfig :: SMTResult -> SMTConfig
resultConfig res = case res of
  Unsatisfiable cfg -> cfg
  Satisfiable cfg _ -> cfg
  Unknown cfg _     -> cfg
  ProofError cfg _  -> cfg
  TimeOut cfg       -> cfg

-- | Deep evaluation forces the model (or the error lines) carried by a
-- result; the embedded 'SMTConfig' is left untouched.
instance NFData SMTResult where
  rnf res = case res of
    Satisfiable _ model -> rnf model
    Unknown _ model     -> rnf model
    ProofError _ msgs   -> rnf msgs
    Unsatisfiable _     -> ()
    TimeOut _           -> ()

-- | The result type of a 'prove' call: an 'SMTResult' wrapped so that it
-- can carry its own 'Show' instance.
newtype ThmResult = ThmResult SMTResult

-- | The result type of a 'sat' call. It is kept separate from 'ThmResult'
-- so that it can have a more meaningful 'Show' instance.
newtype SatResult = SatResult SMTResult

-- | The result type of an 'allSat' call: a list of individual 'SMTResult's.
newtype AllSatResult = AllSatResult [SMTResult]

-- | Renders a proof result in theorem-proving vocabulary: an unsatisfiable
-- outcome is shown as @Q.E.D.@, a satisfiable one as a counter-example.
instance Show ThmResult where
  show (ThmResult res) = showSMTResult proved unknown unknownModel falsified falsifiedModel res
    where
      proved         = "Q.E.D."
      unknown        = "Unknown"
      unknownModel   = "Unknown. Potential counter-example:\n"
      falsified      = "Falsifiable"
      falsifiedModel = "Falsifiable. Counter-example:\n"

-- | Renders a satisfiability result, printing the model when one is present.
instance Show SatResult where
  show (SatResult res) = showSMTResult unsat unknown unknownModel sat satModel res
    where
      unsat        = "Unsatisfiable"
      unknown      = "Unknown"
      unknownModel = "Unknown. Potential model:\n"
      sat          = "Satisfiable"
      satModel     = "Satisfiable. Model:\n"


-- | Renders the results of an 'allSat' call. Zero and one results get a
-- dedicated message; otherwise each result is printed under a running
-- 1-based index, followed by a closing @Done.@ line.
instance Show AllSatResult where
  show (AllSatResult [])  =  "No solutions found"
  show (AllSatResult [s]) =  "One solution found\n" ++ show (SatResult s)
  show (AllSatResult ss)  =  "Multiple solutions found:\n"       -- shouldn't display how-many; would be too slow/leak-space to compute everything..
                          ++ unlines (zipWith sh [(1::Int)..] ss)
                          ++ "Done."
        where -- The model-less "Unknown" and "Solution" headers are kept in the same
              -- "<kind> #<n> (No assignment ...)" shape, including the space before the paren.
              sh i s = showSMTResult "Unsatisfiable"
                                     ("Unknown #" ++ show i ++ " (No assignment to variables returned)") "Unknown. Potential assignment:\n"
                                     ("Solution #" ++ show i ++ " (No assignment to variables returned)") ("Solution #" ++ show i ++ ":\n") s

-- | Types that can be read back from a solver model. The sbv infrastructure
-- supplies the model as a stream of 'CW's (constant-words) coming from the
-- solver, and an instance describes how a value of type @a@ is assembled from
-- the front of that stream. Instances for many common types are provided, so
-- declaring new ones is relatively easy.
--
-- Minimum complete definition: 'parseCWs'
class SatModel a where
  -- | Consume constant-words from the front of the list to build one value of
  -- type @a@, returning it together with the untouched remainder. Return
  -- 'Nothing' if the next element is not what this type expects.
  parseCWs  :: [CW] -> Maybe (a, [CW])
  -- | Apply @f@ to an already parsed value, keeping the remaining
  -- constant-words as they are. The result is 'Nothing' if either the parse
  -- or @f@ yields 'Nothing'. The default definition should be sufficient in
  -- most use cases.
  cvtModel  :: (a -> Maybe b) -> Maybe (a, [CW]) -> Maybe (b, [CW])
  cvtModel f parsed = do
    (val, rest) <- parsed
    val'        <- f val
    return (val', rest)

-- | A 'Bool' is read from a leading 'W1' constant, converted with 'bit2Bool'.
instance SatModel Bool where
  parseCWs cws = case cws of
    W1 b : rest -> Just (bit2Bool b, rest)
    _           -> Nothing

-- | A 'Word8' is read from a leading 'W8' constant.
instance SatModel Word8 where
  parseCWs cws = case cws of
    W8 w : rest -> Just (w, rest)
    _           -> Nothing

-- | An 'Int8' is read from a leading 'I8' constant.
instance SatModel Int8 where
  parseCWs cws = case cws of
    I8 w : rest -> Just (w, rest)
    _           -> Nothing

-- | A 'Word16' is read from a leading 'W16' constant.
instance SatModel Word16 where
  parseCWs cws = case cws of
    W16 w : rest -> Just (w, rest)
    _            -> Nothing

-- | An 'Int16' is read from a leading 'I16' constant.
instance SatModel Int16 where
  parseCWs cws = case cws of
    I16 w : rest -> Just (w, rest)
    _            -> Nothing

-- | A 'Word32' is read from a leading 'W32' constant.
instance SatModel Word32 where
  parseCWs cws = case cws of
    W32 w : rest -> Just (w, rest)
    _            -> Nothing

-- | An 'Int32' is read from a leading 'I32' constant.
instance SatModel Int32 where
  parseCWs cws = case cws of
    I32 w : rest -> Just (w, rest)
    _            -> Nothing

-- | A 'Word64' is read from a leading 'W64' constant.
instance SatModel Word64 where
  parseCWs cws = case cws of
    W64 w : rest -> Just (w, rest)
    _            -> Nothing

-- | An 'Int64' is read from a leading 'I64' constant.
instance SatModel Int64 where
  parseCWs cws = case cws of
    I64 w : rest -> Just (w, rest)
    _            -> Nothing

-- | Lists are read by maximal munch: elements are parsed for as long as the
-- element parser succeeds. The parse therefore never fails; at worst it
-- yields the empty list and leaves the input untouched.
--
-- NOTE(review): an element parser that succeeds without consuming any input
-- (for instance a nested list type facing a non-matching constant) makes this
-- recurse on the same input forever. Confirm whether such instantiations are
-- used.
instance SatModel a => SatModel [a] where
  parseCWs [] = Just ([], [])
  parseCWs xs = case parseCWs xs of
                  -- the head is not an element: stop, consuming nothing
                  Nothing      -> Just ([], xs)
                  Just (a, ys) -> case parseCWs ys of
                                    Just (as, zs) -> Just (a:as, zs)
                                    -- The recursive list parse cannot fail, so this branch
                                    -- is defensive; it must still keep the element that
                                    -- was already parsed rather than drop it.
                                    Nothing       -> Just ([a], ys)

-- | A pair is read as its first component followed by its second; the parse
-- fails if either component fails.
instance (SatModel a, SatModel b) => SatModel (a, b) where
  parseCWs cws =
    parseCWs cws  >>= \(x, rest)  ->
    parseCWs rest >>= \(y, rest') ->
    Just ((x, y), rest')

-- | A triple is read as its first component followed by the remaining pair.
instance (SatModel a, SatModel b, SatModel c) => SatModel (a, b, c) where
  parseCWs cws =
    parseCWs cws  >>= \(x, rest)       ->
    parseCWs rest >>= \((y, z), rest') ->
    Just ((x, y, z), rest')

-- | A 4-tuple is read as its first component followed by the remaining triple.
instance (SatModel a, SatModel b, SatModel c, SatModel d) => SatModel (a, b, c, d) where
  parseCWs cws =
    parseCWs cws  >>= \(x, rest)          ->
    parseCWs rest >>= \((y, z, w), rest') ->
    Just ((x, y, z, w), rest')

-- | A 5-tuple is read as its first component followed by the remaining 4-tuple.
instance (SatModel a, SatModel b, SatModel c, SatModel d, SatModel e) => SatModel (a, b, c, d, e) where
  parseCWs cws =
    parseCWs cws  >>= \(x, rest)             ->
    parseCWs rest >>= \((y, z, w, v), rest') ->
    Just ((x, y, z, w, v), rest')

-- | A 6-tuple is read as its first component followed by the remaining 5-tuple.
instance (SatModel a, SatModel b, SatModel c, SatModel d, SatModel e, SatModel f) => SatModel (a, b, c, d, e, f) where
  parseCWs cws =
    parseCWs cws  >>= \(x, rest)                ->
    parseCWs rest >>= \((y, z, w, v, u), rest') ->
    Just ((x, y, z, w, v, u), rest')

-- | A 7-tuple is read as its first component followed by the remaining 6-tuple.
instance (SatModel a, SatModel b, SatModel c, SatModel d, SatModel e, SatModel f, SatModel g) => SatModel (a, b, c, d, e, f, g) where
  parseCWs cws =
    parseCWs cws  >>= \(x, rest)                   ->
    parseCWs rest >>= \((y, z, w, v, u, t), rest') ->
    Just ((x, y, z, w, v, u, t), rest')

-- | Extract a typed model from an 'SMTResult' by means of the 'SatModel'
-- instance for the requested type.
--
-- This function is partial: it calls 'error' for every result other than
-- 'Satisfiable', and also when the model's constant-words cannot be parsed
-- as the requested type or are not consumed completely by the parse.
getModel :: SatModel a => SMTResult -> a
getModel result = case result of
    Satisfiable _ model -> fromModel model
    Unsatisfiable _     -> error "SatModel.getModel: Unsatisfiable result"
    Unknown _ _         -> error "Impossible! Backend solver returned unknown for Bit-vector problem!"
    ProofError _ msgs   -> error (unlines ("An error happened: " : msgs))
    TimeOut _           -> error "Timeout"
  where
    -- Only the values take part in parsing; the variable names are dropped.
    fromModel model = case parseCWs [cw | (_, cw) <- model] of
      Nothing            -> error ("SBV.getModel: Cannot construct a model from: " ++ show model)
      Just (val, [])     -> val
      Just (_, leftover) -> error ("SBV.getModel: Partially constructed model; remaining elements: " ++ show leftover)

-- | Iterate over the results of an 'allSat' call, typically to print them in
-- sequence. Each result is converted with 'getModel' and handed to @disp@,
-- one after the other. The first ('Int') argument to @disp@ is the current
-- model number, counting from 1. The returned value is the number passed to
-- the last call of @disp@, or 0 if there were no results.
displayModels :: SatModel a => (Int -> a -> IO ()) -> AllSatResult -> IO Int
displayModels disp (AllSatResult results) = do
    shown <- zipWithM showOne [1 ..] results
    -- the leading 0 covers the case of an empty result list
    return (last (0 : shown))
  where showOne k r = do disp k (getModel r)
                         return k

-- | Render an 'SMTResult' using caller-supplied wording. The arguments are,
-- in order: the message for an unsatisfiable result, the message for an
-- unknown result without a model, the header placed before the model of an
-- unknown result, the message for a satisfiable result without a model, and
-- the header placed before the model of a satisfiable result. Model entries
-- are printed one per line with 'shM'. Errors and timeouts use fixed text.
showSMTResult :: String -> String -> String -> String -> String -> SMTResult -> String
showSMTResult unsatMsg unkMsg unkMsgModel satMsg satMsgModel result = render result
  where
    render (Unsatisfiable _) = unsatMsg
    render (Satisfiable _ m) = withModel satMsg satMsgModel m
    render (Unknown _ m)     = withModel unkMsg unkMsgModel m
    render (ProofError _ []) = "*** An error occurred. No additional information available. Try running in verbose mode"
    render (ProofError _ ls) = "*** An error occurred.\n" ++ intercalate "\n" (map ("***  " ++) ls)
    render (TimeOut _)       = "*** Timeout"

    -- An empty model gets the bare message; otherwise the header is followed
    -- by the model entries.
    withModel bare _      []    = bare
    withModel _    header model = header ++ intercalate "\n" (map (shM (resultConfig result)) model)

-- | Render one model entry as an indented @name = value@ line. The value is
-- printed in the base given by 'printBase' (2, 10 or 16); for any other base
-- it is printed with 'show', followed by a note about the unsupported base.
shM :: SMTConfig -> (String, CW) -> String
shM cfg (s, v) = "  " ++ s ++ " = " ++ rendered
  where
    base     = printBase cfg
    rendered = case base of
      2  -> binS v
      10 -> show v
      16 -> hexS v
      _  -> show v ++ " -- Ignoring unsupported printBase " ++ show base ++ ", use 2, 10, or 16."

-- | Run an external program, feeding it @script@ on standard input.
--
-- Arguments, in order: a printable name used in error messages, the name of
-- the executable (looked up with 'findExecutable'), its command-line options,
-- and the text to send on standard input.
--
-- The result is 'Right' with the non-blank lines of standard output, each
-- stripped of surrounding whitespace, when the program exits successfully
-- without writing to standard error. It is 'Left' with a diagnostic when the
-- executable cannot be located, when anything was written to standard error,
-- or when the exit code is non-zero.
pipeProcess :: String -> String -> [String] -> String -> IO (Either String [String])
pipeProcess nm execName opts script = do
        mbExecPath <- findExecutable execName
        case mbExecPath of
          Nothing -> return $ Left $ "Unable to locate executable for " ++ nm
                                   ++ "\nExecutable specified: " ++ show execName
          Just execPath -> do (ec, contents, errors) <- readProcessWithExitCode execPath opts script
                              case ec of
                                ExitSuccess
                                  -- Trim first, then drop blanks: a whitespace-only line
                                  -- (e.g. a lone "\r") must not survive as an empty string.
                                  | null errors -> return $ Right $ filter (not . null) (map clean (lines contents))
                                  | otherwise   -> return $ Left errors
                                ExitFailure n -> return $ Left $  "Failed to invoke " ++ nm
                                                               ++ "\nExecutable: " ++ show execPath
                                                               ++ "\nOptions   : " ++ unwords opts
                                                               ++ "\nExit code : " ++ show n
                                                               ++ stderrReport errors
  where clean = reverse . dropWhile isSpace . reverse . dropWhile isSpace
        -- Include whatever the program wrote to standard error in the failure
        -- report; if it wrote nothing the message is left as it was.
        stderrReport errs
          | all isSpace errs = ""
          | otherwise        = "\nStderr    : " ++ clean errs

-- | Invoke the solver named in the configuration on @script@ via
-- 'pipeProcess', and post-process the outcome. On failure the error text is
-- split into lines and passed to @failure@; on success the output lines are
-- passed to @success@. The call is wrapped in 'timeIf' using the 'timing'
-- setting, and in 'verbose' mode the command line and the solver output are
-- echoed, each prefixed with @**@.
standardSolver :: SMTConfig -> String -> ([String] -> a) -> ([String] -> a) -> IO a
standardSolver config script failure success = do
    debug $ "Calling: " ++ show (unwords (exec : opts))
    contents <- timeIf (timing config) solverName (pipeProcess solverName exec opts script)
    debug $ solverName ++ " output:\n" ++ either id (intercalate "\n") contents
    return (either (failure . lines) success contents)
  where
    backend    = solver config
    exec       = executable backend
    opts       = options backend
    solverName = name backend
    -- prints only when running in verbose mode
    debug s    = when (verbose config) (putStrLn ("** " ++ s))