{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DefaultSignatures     #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE NamedFieldPuns        #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# OPTIONS_GHC -fno-warn-orphans  #-}
module Data.SBV.Control.Utils (
       io
     , ask, send, getValue, getUninterpretedValue, getValueCW, getUnsatAssumptions, SMTValue(..)
     , getQueryState, modifyQueryState, getConfig, getObjectives, getSBVAssertions, getQuantifiedInputs
     , checkSat, checkSatUsing, getAllSatResult
     , inNewContext, freshVar, freshVar_
     , parse
     , unexpected
     , timeout
     , queryDebug
     , retrieveResponse
     , runProofOn
     ) where
import Data.List  (sortBy, elemIndex, partition, groupBy, tails)
import Data.Ord      (comparing)
import Data.Function (on)
import Data.Int
import Data.Word
import qualified Data.Map as Map
import Control.Monad            (unless)
import Control.Monad.State.Lazy (get, liftIO)
import Data.IORef (readIORef, writeIORef)
import Data.Time (getZonedTime)
import Data.SBV.Core.Data     ( SW(..), CW(..), SBV, AlgReal, sbvToSW, kindOf, Kind(..)
                              , HasKind(..), mkConstCW, CWVal(..), SMTResult(..)
                              , NamedSymVar, SMTConfig(..), Query, SMTModel(..)
                              , QueryState(..), SVal(..), Quantifier(..), cache
                              , newExpr, SBVExpr(..), Op(..), FPOp(..), SBV(..)
                              , SolverContext(..), SBool, Objective(..), SolverCapabilities(..), capabilities
                              , Result(..), SMTProblem(..), trueSW, SymWord(..)
                              )
import Data.SBV.Core.Symbolic (IncState(..), withNewIncState, State(..), svToSW, registerLabel, svMkSymVar)
import Data.SBV.Core.AlgReals   (mergeAlgReals)
import Data.SBV.Core.Operations (svNot, svNotEqual, svOr)
import Data.SBV.SMT.SMTLib  (toIncSMTLib, toSMTLib)
import Data.SBV.SMT.Utils   (showTimeoutValue, annotateWithName, alignDiagnostic, alignPlain, debug, mergeSExpr)
import Data.SBV.Utils.SExpr
import Data.SBV.Control.Types
import qualified Data.Set as Set (toList)
import GHC.Stack
-- | 'Query' computations are solver contexts: constraints are forwarded to the
-- solver via 'addQueryConstraint', and options are sent with 'send' unless
-- they are start-up-only options, which are rejected with an error.
instance SolverContext Query where
   constrain            = addQueryConstraint Nothing
   namedConstraint nm b = addQueryConstraint (Just nm) b
   setOption o =
     if isStartModeOption o
        then error $ unlines [ ""
                             , "*** Data.SBV: '" ++ show o ++ "' can only be set at start-up time."
                             , "*** Hint: Move the call to 'setOption' before the query."
                             ]
        else send True (setSMTOption o)
-- | Assert a boolean in the solver, optionally attaching a name to it.
-- When a name is given it is first registered via 'registerLabel', and the
-- asserted term is wrapped with 'annotateWithName'.
addQueryConstraint :: Maybe String -> SBool -> Query ()
addQueryConstraint mbNm b = do
     sw <- inNewContext $ \st -> do case mbNm of
                                      Nothing -> return ()
                                      Just nm -> registerLabel st nm
                                    sbvToSW st b
     let body = show sw
         term = case mbNm of
                  Nothing -> body
                  Just nm -> annotateWithName nm body
     send True $ "(assert " ++ term  ++ ")"
-- | The solver configuration stored in the current query state.
getConfig :: Query SMTConfig
getConfig = fmap queryConfig getQueryState
-- | Read the optimization goals recorded in the state. The stored list is
-- reversed before being returned.
getObjectives :: Query [Objective (SW, SW)]
getObjectives = do st <- get
                   goals <- io $ readIORef (rOptGoals st)
                   return (reverse goals)
-- | Read the assertions recorded in the state. The stored list is reversed
-- before being returned.
getSBVAssertions :: Query [(String, Maybe CallStack, SW)]
getSBVAssertions = do st <- get
                      asserts <- io $ readIORef (rAsserts st)
                      return (reverse asserts)
-- | Run an 'IO' action inside the 'Query' monad; a specialization of 'liftIO'.
io :: IO a -> Query a
io act = liftIO act
-- | Push the definitions collected in an incremental state to the solver.
-- The new inputs, kinds, constants and assignments are read from the
-- 'IncState', converted with 'toIncSMTLib', and each resulting command is
-- sent with success required.
syncUpSolver :: IncState -> Query ()
syncUpSolver is = do
        cfg <- getConfig
        ls  <- io $ do inps   <- reverse <$> readIORef (rNewInps is)
                       ks     <- readIORef (rNewKinds is)
                       cmap   <- readIORef (rNewConsts is)
                       as     <- readIORef (rNewAsgns is)
                       -- Flip each map entry to (value, second key component), ordered by the value.
                       let cnsts = sortBy (comparing fst) [(v, c) | ((_, c), v) <- Map.toList cmap]
                       return $ toIncSMTLib cfg inps ks cnsts as cfg
        mapM_ (send True) (mergeSExpr ls)
-- | Fetch the current 'QueryState'. Errors out if no query state has been
-- installed, i.e. when called outside of query mode.
getQueryState :: Query QueryState
getQueryState = get >>= io . readIORef . queryState >>= maybe notInQueryMode return
  where notInQueryMode = error $ unlines [ ""
                                         , "*** Data.SBV: Impossible happened: Query context required in a non-query mode."
                                         , "Please report this as a bug!"
                                         ]
-- | Apply a function to the current 'QueryState' and store the result.
-- The new state is forced (to weak head normal form) before it is written.
-- Errors out if no query state has been installed.
modifyQueryState :: (QueryState -> QueryState) -> Query ()
modifyQueryState f = do
     st <- get
     let ref = queryState st
     mbQS <- io $ readIORef ref
     case mbQS of
       Just qs -> do let !fqs = f qs
                     io $ writeIORef ref (Just fqs)
       Nothing -> error $ unlines [ ""
                                  , "*** Data.SBV: Impossible happened: Query context required in a non-query mode."
                                  , "Please report this as a bug!"
                                  ]
-- | Run an action against the state inside a fresh incremental context
-- (via 'withNewIncState'), then sync whatever it produced to the solver
-- before returning the action's result.
inNewContext :: (State -> IO a) -> Query a
inNewContext act = do
     st <- get
     (is, result) <- io (withNewIncState st act)
     syncUpSolver is >> return result
-- | Create a fresh, unnamed, existentially quantified variable in query mode.
freshVar_ :: forall a. SymWord a => Query (SBV a)
freshVar_ = inNewContext mk
  where mk st = SBV <$> svMkSymVar (Just EX) (kindOf (undefined :: a)) Nothing st
-- | Create a fresh, existentially quantified variable with the given name
-- in query mode.
freshVar :: forall a. SymWord a => String -> Query (SBV a)
freshVar nm = inNewContext mk
  where mk st = SBV <$> svMkSymVar (Just EX) (kindOf (undefined :: a)) (Just nm) st
-- | Emit debug messages through 'debug', using the configuration stored in
-- the current query state.
queryDebug :: [String] -> Query ()
queryDebug msgs = getQueryState >>= \qs -> io (debug (queryConfig qs) msgs)
-- | Send a string to the solver and return its response. Both the command
-- and the response are logged through 'queryDebug'. The currently active
-- time-out value, if any, is passed along to the underlying ask function and
-- shown in the @[SEND]@ tag.
ask :: String -> Query String
ask s = do QueryState{queryAsk, queryTimeOutValue} <- getQueryState

           -- Build the complete tag first. A backquoted operator binds tighter than
           -- '++' by default, so writing the concatenation inline would hand only
           -- the trailing "] " to 'alignPlain'. This mirrors the parenthesization
           -- used for the [FAIL, TimeOut: ..] tag in 'send'.
           let tag = case queryTimeOutValue of
                       Nothing -> "[SEND] "
                       Just i  -> "[SEND, TimeOut: " ++ showTimeoutValue i ++ "] "

           queryDebug [tag `alignPlain` s]
           r <- io $ queryAsk queryTimeOutValue s
           queryDebug ["[RECV] " `alignPlain` r]
           return r
-- | Send a command to the solver. When @requireSuccess@ is set and the solver
-- supports custom queries, the solver's reply is read and must be exactly
-- @success@; anything else is reported through 'unexpected'. Otherwise the
-- command is sent without reading a reply.
send :: Bool -> String -> Query ()
send requireSuccess s = do
     QueryState{queryAsk, querySend, queryConfig, queryTimeOutValue} <- getQueryState
     let expectAck = requireSuccess && supportsCustomQueries (capabilities (solver queryConfig))
         failTag   = case queryTimeOutValue of
                       Nothing -> "[FAIL] "
                       Just i  -> "[FAIL, TimeOut: " ++ showTimeoutValue i ++ "]  "
     if not expectAck
        then io $ querySend queryTimeOutValue s
        else do r <- io $ queryAsk queryTimeOutValue s
                if words r == ["success"]
                   then queryDebug ["[GOOD] " `alignPlain` s]
                   else do queryDebug [failTag `alignPlain` s]
                           unexpected "Command" s "success" Nothing r Nothing
-- | Collect all pending responses from the solver. A time-stamped tag is
-- echoed to the solver, and responses are read (each with the given optional
-- time-out) until that tag comes back. The responses received before the tag
-- are returned in the order they arrived.
retrieveResponse :: String -> Maybe Int -> Query [String]
retrieveResponse userTag mbTo = do
     stamp <- io (show <$> getZonedTime)
     let synchTag = show $ userTag ++ " (at: " ++ stamp ++ ")"
         cmd      = "(echo " ++ synchTag ++ ")"
     queryDebug ["[SYNC] Attempting to synchronize with tag: " ++ synchTag]
     send False cmd
     QueryState{queryRetrieveResponse} <- getQueryState
     let -- Accept the tag either verbatim or once re-quoted with 'show';
         -- presumably some solvers echo the string without its quotes.
         isSynch s = s == synchTag || show s == synchTag
         collect acc = do
            s <- io $ queryRetrieveResponse mbTo
            if isSynch s
               then do queryDebug ["[SYNC] Synchronization achieved using tag: " ++ synchTag]
                       return $ reverse acc
               else do queryDebug ["[RECV] " `alignPlain` s]
                       collect (s : acc)
     collect []
-- | Values that can be recovered from an S-expression returned by the solver.
-- The default implementation handles an 'ECon' whose text is fully consumed
-- by the type's 'Read' instance; everything else yields 'Nothing'.
class SMTValue a where
  sexprToVal :: SExpr -> Maybe a
  default sexprToVal :: Read a => SExpr -> Maybe a
  sexprToVal = \case ECon c | [(v, "")] <- reads c -> Just v
                     _                             -> Nothing
-- | Convert a numeric S-expression to any integral type via 'fromIntegral'.
-- Non-numeric expressions yield 'Nothing'.
fromIntegralToVal :: Integral a => SExpr -> Maybe a
fromIntegralToVal = \case ENum (i, _) -> Just (fromIntegral i)
                          _           -> Nothing
-- Signed fixed-width integers: read from a numeric S-expression.
instance SMTValue Int8 where
   sexprToVal = fromIntegralToVal
instance SMTValue Int16 where
   sexprToVal = fromIntegralToVal
instance SMTValue Int32 where
   sexprToVal = fromIntegralToVal
instance SMTValue Int64 where
   sexprToVal = fromIntegralToVal
-- Unsigned fixed-width integers: read from a numeric S-expression.
instance SMTValue Word8 where
   sexprToVal = fromIntegralToVal
instance SMTValue Word16 where
   sexprToVal = fromIntegralToVal
instance SMTValue Word32 where
   sexprToVal = fromIntegralToVal
instance SMTValue Word64 where
   sexprToVal = fromIntegralToVal
-- Unbounded integers: read from a numeric S-expression.
instance SMTValue Integer where
   sexprToVal = fromIntegralToVal
-- | Single-precision floats come from 'EFloat' expressions only.
instance SMTValue Float where
   sexprToVal = \case EFloat f -> Just f
                      _        -> Nothing
-- | Double-precision floats come from 'EDouble' expressions only.
instance SMTValue Double where
   sexprToVal = \case EDouble d -> Just d
                      _         -> Nothing
-- | Booleans are read from numeric expressions: 1 is 'True', 0 is 'False',
-- and anything else is rejected.
instance SMTValue Bool where
   sexprToVal = \case ENum (1, _) -> Just True
                      ENum (0, _) -> Just False
                      _           -> Nothing
-- | Algebraic reals come from 'EReal' expressions only.
instance SMTValue AlgReal where
   sexprToVal = \case EReal a -> Just a
                      _       -> Nothing
-- | Ask the solver for the value of a symbolic expression via @get-value@.
-- The reply must be a single binding for the queried name whose value
-- converts with 'sexprToVal'; any other reply is reported through
-- 'unexpected'.
getValue :: SMTValue a => SBV a -> Query a
getValue s = do
     sw <- inNewContext (`sbvToSW` s)
     let nm  = show sw
         cmd = "(get-value (" ++ nm ++ "))"
         bad = unexpected "getValue" cmd "a model value" Nothing
     r <- ask cmd
     parse r bad $ \case
        EApp [EApp [ECon o, v]] | o == nm, Just c <- sexprToVal v -> return c
        _                                                         -> bad r Nothing
-- | Ask the solver for the value of an expression of an uninterpreted sort,
-- returned as the constant name the solver reports. Only kinds of the form
-- @KUserSort _ (Left _)@ are accepted; calling this on any other kind is an
-- error that points the user to 'getValue'.
getUninterpretedValue :: HasKind a => SBV a -> Query String
getUninterpretedValue s =
        case kindOf s of
          KUserSort _ (Left _) -> do sw <- inNewContext (`sbvToSW` s)
                                     let nm  = show sw
                                         cmd = "(get-value (" ++ nm ++ "))"
                                         -- Name this function in diagnostics, so an unexpected
                                         -- solver response is not attributed to 'getValue'.
                                         bad = unexpected "getUninterpretedValue" cmd "a model value" Nothing
                                     r <- ask cmd
                                     parse r bad $ \case EApp [EApp [ECon o,  ECon v]] | o == nm -> return v
                                                         _                                        -> bad r Nothing
          k                    -> error $ unlines [""
                                                  , "*** SBV.getUninterpretedValue: Called on an 'interpreted' kind"
                                                  , "*** "
                                                  , "***    Kind: " ++ show k
                                                  , "***    Hint: Use 'getValue' to extract value for interpreted kinds."
                                                  , "*** "
                                                  , "*** Only truly uninterpreted sorts should be used with 'getUninterpretedValue.'"
                                                  ]
-- | Retrieve the value of a single 'SW' from the solver as a 'CW', using
-- @get-value@. When a model index is supplied, a @:model_index@ attribute is
-- appended to the command. The reply must be a single binding for the queried
-- name whose value form matches the kind of the 'SW'; anything else is
-- reported through 'unexpected'.
getValueCWHelper :: Maybe Int -> SW -> Query CW
getValueCWHelper mbi s = do
       let nm  = show s
           k   = kindOf s
           -- Optional attribute selecting a particular model.
           modelIndex = case mbi of
                          Nothing -> ""
                          Just i  -> " :model_index " ++ show i
           cmd        = "(get-value (" ++ nm ++ ")" ++ modelIndex ++ ")"
           bad = unexpected "getModel" cmd ("a value binding for kind: " ++ show k) Nothing
           -- For user sorts that carry a list of names, the position of the
           -- returned name in that list; 'Nothing' otherwise.
           getUIIndex (KUserSort  _ (Right xs)) i = i `elemIndex` xs
           getUIIndex _                         _ = Nothing
       r <- ask cmd
       let isIntegral sw = isBoolean sw || isBounded sw || isInteger sw
           -- Each value form is accepted only when it agrees with the kind of 's'.
           extract (ENum    i) | isIntegral      s = return $ mkConstCW  k (fst i)
           extract (EReal   i) | isReal          s = return $ CW KReal   (CWAlgReal i)
           extract (EFloat  i) | isFloat         s = return $ CW KFloat  (CWFloat   i)
           extract (EDouble i) | isDouble        s = return $ CW KDouble (CWDouble  i)
           extract (ECon    i) | isUninterpreted s = return $ CW k       (CWUserSort (getUIIndex k i, i))
           extract _                               = bad r Nothing
       parse r bad $ \case EApp [EApp [ECon v, val]] | v == nm -> extract val
                           _                                   -> bad r Nothing
-- | Retrieve the value of an 'SW' as a 'CW'. Non-real kinds, and reals on
-- solvers without approximate-real support, are fetched with a single
-- 'getValueCWHelper' call. Otherwise the value is fetched twice, first with
-- @pp.decimal@ off and then with it on at the configured precision, and the
-- two results are combined with 'mergeAlgReals'.
getValueCW :: Maybe Int -> SW -> Query CW
getValueCW mbi s
  | kindOf s == KReal = do cfg <- getConfig
                           if supportsApproxReals (capabilities (solver cfg))
                              then bothForms cfg
                              else single
  | True              = single
  where single = getValueCWHelper mbi s

        bothForms cfg = do
            send True "(set-option :pp.decimal false)"
            rep1 <- single
            send True   "(set-option :pp.decimal true)"
            send True $ "(set-option :pp.decimal_precision " ++ show (printRealPrec cfg) ++ ")"
            rep2 <- single
            case (rep1, rep2) of
              (CW KReal (CWAlgReal a), CW KReal (CWAlgReal b)) -> return $ CW KReal (CWAlgReal (mergeAlgReals ("Cannot merge real-values for " ++ show s) a b))
              _                                                -> unexpected "getValueCW" "get-value" ("a real-valued binding for " ++ show s) Nothing (show (rep1, rep2)) Nothing
-- | Check satisfiability using the command given by 'satCmd' in the
-- current configuration.
checkSat :: Query CheckSatResult
checkSat = getConfig >>= \cfg -> checkSatUsing (satCmd cfg)
-- | Check satisfiability by sending the given command. The reply must be one
-- of @sat@, @unsat@ or @unknown@; anything else is reported through
-- 'unexpected'.
checkSatUsing :: String -> Query CheckSatResult
checkSatUsing cmd = do
     r <- ask cmd
     parse r bad $ \case
        ECon c | Just res <- lookup c verdicts -> return res
        _                                      -> bad r Nothing
  where bad      = unexpected "checkSat" cmd "one of sat/unsat/unknown" Nothing
        verdicts = [("sat", Sat), ("unsat", Unsat), ("unknown", Unk)]
-- | Read the quantified inputs recorded in the state. The stored list is
-- reversed before being returned.
getQuantifiedInputs :: Query [(Quantifier, NamedSymVar)]
getQuantifiedInputs = do st <- get
                         inps <- liftIO $ readIORef (rinps st)
                         return (reverse inps)
-- | Enumerate satisfying models by repeatedly calling 'checkSat' and then
-- constraining the solver so the model just found cannot be returned again.
--
-- The result is a triple:
--
--   * 'True' if the search stopped because the configured maximum model
--     count ('allSatMaxModelCount') was exceeded, 'False' otherwise;
--   * 'True' if any of the quantified inputs is universally quantified;
--   * the models found, in the order they were discovered.
getAllSatResult :: Query (Bool, Bool, [SMTResult])
getAllSatResult = do queryDebug ["*** Checking Satisfiability, all solutions.."]
                     cfg <- getConfig
                     State{rUsedKinds} <- get
                     ki    <- liftIO $ readIORef rUsedKinds
                     qinps <- getQuantifiedInputs
                     -- Names of the user sorts in use that 'isFree' accepts (the 'Left' form).
                     let usorts = [s | us@(KUserSort s _) <- Set.toList ki, isFree us]
                     unless (null usorts) $ queryDebug [ "*** SBV.allSat: Uninterpreted sorts present: " ++ unwords usorts
                                                       , "***             SBV will use equivalence classes to generate all-satisfying instances."
                                                       ]
                     -- The variables that make up a model: the inputs preceding the first
                     -- universally quantified one, minus those rejected by 'isNonModelVar',
                     -- ordered by node id and paired with an 'SVal' that refers back to them.
                     let vars :: [(SVal, NamedSymVar)]
                         vars = let allModelInputs = takeWhile ((/= ALL) . fst) qinps
                                    sortByNodeId :: [NamedSymVar] -> [NamedSymVar]
                                    sortByNodeId = sortBy (compare `on` (\(SW _ n, _) -> n))
                                    mkSVal :: NamedSymVar -> (SVal, NamedSymVar)
                                    mkSVal nm@(sw, _) = (SVal (kindOf sw) (Right (cache (const (return sw)))), nm)
                                in map mkSVal $ sortByNodeId [nv | (_, nv@(_, n)) <- allModelInputs, not (isNonModelVar cfg n)]
                         -- Whether any input is universally quantified.
                         w = ALL `elem` map fst qinps
                     (sc, ms) <- loop vars cfg
                     return (sc, w, reverse ms)
   where isFree (KUserSort _ (Left _)) = True
         isFree _                      = False
         loop vars cfg = go (1::Int) []
            where go :: Int -> [SMTResult] -> Query (Bool, [SMTResult])
                  go !cnt sofar
                   | Just maxModels <- allSatMaxModelCount cfg, cnt > maxModels
                   = do queryDebug ["*** Maximum model count request of " ++ show maxModels ++ " reached, stopping the search."]
                        return (True, sofar)
                   | True
                   = do queryDebug ["Looking for solution " ++ show cnt]
                        cs <- checkSat
                        case cs of
                          Unsat -> return (False, sofar)
                          Unk   -> do queryDebug ["*** Solver returned unknown, terminating query."]
                                      return (False, sofar)
                          Sat   -> do assocs <- mapM (\(sval, (sw, n)) -> do cw <- getValueCW Nothing sw
                                                                             return (n, (sval, cw))) vars
                                      let m = Satisfiable cfg SMTModel { modelObjectives = []
                                                                       , modelAssocs     = [(n, cw) | (n, (_, cw)) <- assocs]
                                                                       }
                                          (interpreteds, uninterpreteds) = partition (not . isFree . kindOf . fst) (map snd assocs)
                                          -- For interpreted values: require the variable to differ from
                                          -- the value it has in this model. Floats and doubles go through
                                          -- the negation of FP_ObjEqual instead of 'svNotEqual'
                                          -- (presumably because IEEE equality treats NaN and signed
                                          -- zeros specially; confirm against FP_ObjEqual's definition).
                                          interpretedEqs :: [SVal]
                                          interpretedEqs = [mkNotEq (kindOf sv) sv (SVal (kindOf sv) (Left cw)) | (sv, cw) <- interpreteds]
                                             where mkNotEq k a b
                                                    | isDouble k || isFloat k = svNot (a `fpNotEq` b)
                                                    | True                    = a `svNotEqual` b
                                                   fpNotEq a b = SVal KBool $ Right $ cache r
                                                       where r st = do swa <- svToSW st a
                                                                       swb <- svToSW st b
                                                                       newExpr st KBool (SBVApp (IEEEFP FP_ObjEqual) [swa, swb])
                                          -- For uninterpreted values: variables that received the same
                                          -- model value are grouped, and each pair within a group
                                          -- contributes a "these two differ" disjunct.
                                          uninterpretedEqs :: [SVal]
                                          uninterpretedEqs = concatMap pwDistinct         -- pairwise disequalities within each group
                                                           . filter (\l -> length l > 1)  -- singleton groups contribute nothing
                                                           . map (map fst)                -- keep only the variables
                                                           . groupBy ((==) `on` snd)      -- group by model value
                                                           . sortBy (comparing snd)       -- bring equal values together
                                                           $ uninterpreteds
                                            where pwDistinct :: [SVal] -> [SVal]
                                                  pwDistinct ss = [x `svNotEqual` y | (x:ys) <- tails ss, y <- ys]
                                          eqs = interpretedEqs ++ uninterpretedEqs
                                          -- The constraint ruling out this model: the disjunction of all
                                          -- the disequalities above, or 'Nothing' when there are none.
                                          disallow = case eqs of
                                                       [] -> Nothing
                                                       _  -> Just $ SBV $ foldr1 svOr eqs
                                      let resultsSoFar = m : sofar
                                      -- With nothing to rule out, the search ends after this model.
                                      case disallow of
                                        Nothing -> return (False, resultsSoFar)
                                        Just d  -> do constrain d
                                                      go (cnt+1) resultsSoFar
-- | Retrieve the unsatisfiable assumptions via @get-unsat-assumptions@.
-- The reply must be a list of names; each name is translated through the
-- proxy map. Names that are not in the map are skipped, with a debug message
-- that lists the originally expected names.
getUnsatAssumptions :: [String] -> [(String, a)] -> Query [a]
getUnsatAssumptions originals proxyMap = do
        r <- ask cmd
        parse r bad $ \case
           EApp es | Just names <- mapM fromECon es -> concat <$> mapM resolve names
           _                                        -> bad r Nothing
  where cmd = "(get-unsat-assumptions)"

        bad = unexpected "getUnsatAssumptions" cmd "a list of unsatisfiable assumptions"
                       $ Just [ "Make sure you use:"
                              , ""
                              , "       setOption $ ProduceUnsatAssumptions True"
                              , ""
                              , "to make sure the solver is ready for producing unsat assumptions,"
                              , "and that there is a model by first issuing a 'checkSat' call."
                              ]

        fromECon = \case ECon s -> Just s
                         _      -> Nothing

        -- Translate one returned name: a singleton on success, empty (after
        -- logging) when the name is not in the proxy map.
        resolve a = case a `lookup` proxyMap of
                      Just v  -> return [v]
                      Nothing -> do queryDebug [ "*** In call to 'getUnsatAssumptions'"
                                               , "***"
                                               , "***    Unexpected assumption named: " ++ show a
                                               , "***    Was expecting one of       : " ++ show originals
                                               , "***"
                                               , "*** This can happen if unsat-cores are also enabled. Ignoring."
                                               ]
                                    return []
-- | Run a query with the given time-out value in effect for the solver
-- interactions it performs. The time-out that was active before the call is
-- restored afterwards, so a nested 'timeout' does not cancel the time-out of
-- an enclosing one. (The unit of the value is whatever the underlying
-- ask/send functions expect; it is not fixed in this block.)
--
-- NOTE: the previous value is not restored if the query exits via an exception.
timeout :: Int -> Query a -> Query a
timeout n q = do QueryState{queryTimeOutValue = previous} <- getQueryState
                 modifyQueryState (\qs -> qs {queryTimeOutValue = Just n})
                 r <- q
                 -- Put back what was there before, rather than unconditionally clearing it.
                 modifyQueryState (\qs -> qs {queryTimeOutValue = previous})
                 return r
-- | Parse a solver response as an S-expression and continue accordingly:
-- on a parse failure the failure continuation receives the raw response and
-- the parser's message; on success the success continuation receives the
-- parsed expression.
parse :: String -> (String -> Maybe [String] -> a) -> (SExpr -> a) -> a
parse r fCont sCont = either onFailure sCont (parseSExpr r)
  where onFailure e = fCont r (Just [e])
-- | Abort with a diagnostic describing an unexpected solver response.
-- Before failing, any further pending output is collected from the solver
-- with 'retrieveResponse' and shown alongside the received response.
-- The optional reason and hint sections are included only when given.
unexpected :: String -> String -> String -> Maybe [String] -> String -> Maybe [String] -> Query a
unexpected ctx sent expected mbHint received mbReason = do
        extras <- retrieveResponse "terminating upon unexpected response" (Just 5000000)
        let header = [ ""
                     , "*** Data.SBV: Unexpected response from the solver."
                     , "***    Context : " `alignDiagnostic` ctx
                     , "***    Sent    : " `alignDiagnostic` sent
                     , "***    Expected: " `alignDiagnostic` expected
                     , "***    Received: " `alignDiagnostic` unlines (received : extras)
                     ]
            section tag = maybe [] (\ls -> [tag `alignDiagnostic` unlines ls])
        error $ unlines $ header
                       ++ section "***    Reason  : " mbReason
                       ++ section "***    Hint    : " mbHint
-- | Build the 'SMTProblem' for a symbolic-execution 'Result' by converting it
-- with 'toSMTLib'. The boolean flag selects sat mode; when it is 'False' the
-- quantifiers of the inputs are flipped before the skolem map is computed.
-- Having more than one output value is a user error.
runProofOn :: SMTConfig -> Bool -> [String] -> Result -> SMTProblem
runProofOn config isSat comments res@(Result ki _qcInfo _codeSegs is consts tbls arrs uis axs pgm cstrs _assertions outputs) =
     let flipQ (ALL, x) = (EX,  x)
         flipQ (EX,  x) = (ALL, x)
         -- Walk the quantified inputs in order: a universal becomes 'Left v'
         -- and is remembered; an existential becomes 'Right' of itself paired
         -- with the universals seen before it, in the order they appeared.
         skolemize :: [(Quantifier, NamedSymVar)] -> [Either SW (SW, [SW])]
         skolemize quants = go quants ([], [])
           where go []                   (_,  sofar) = reverse sofar
                 go ((ALL, (v, _)):rest) (us, sofar) = go rest (v:us, Left v : sofar)
                 go ((EX,  (v, _)):rest) (us, sofar) = go rest (us,   Right (v, reverse us) : sofar)
         qinps      = if isSat then is else map flipQ is
         skolemMap  = skolemize qinps
         -- The output to use: no output, or a single non-boolean output,
         -- falls back to 'trueSW'; a single boolean output is used as is.
         o = case outputs of
               []  -> trueSW
               [so] -> case so of
                        SW KBool _ -> so
                        _          -> trueSW
                                      -- A non-boolean single output is accepted here rather than rejected.
               os  -> error $ unlines [ "User error: Multiple output values detected: " ++ show os
                                      , "Detected while generating the trace:\n" ++ show res
                                      , "*** Check calls to \"output\", they are typically not needed!"
                                      ]
     in SMTProblem { smtLibPgm = toSMTLib config ki isSat comments is skolemMap consts tbls arrs uis axs pgm cstrs o }
{-# ANN module ("HLint: ignore Reduce duplication" :: String) #-}