{-# LANGUAGE DeriveGeneric     #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE BangPatterns      #-}

-- | This is a wrapper around IO that permits SMT queries

module Language.Fixpoint.Solver.Monad
       ( -- * Type
         SolveM

         -- * Execution
       , runSolverM

         -- * Get Binds
       , getBinds

         -- * SMT Query
       , filterValid
       , checkSat, smtEnablrmbqi

         -- * Debug
       , Stats
       , tickIter
       , stats
       , numIter
       )
       where

import           Control.DeepSeq
import           GHC.Generics
import           Language.Fixpoint.Utils.Progress
import           Language.Fixpoint.Misc    (groupList)
import qualified Language.Fixpoint.Types.Config  as C
import           Language.Fixpoint.Types.Config  (Config, solver, linear, SMTSolver(Z3))
import qualified Language.Fixpoint.Types   as F
import           Language.Fixpoint.Types   (pprint)
import qualified Language.Fixpoint.Types.Errors  as E
import qualified Language.Fixpoint.Smt.Theories as Thy
import           Language.Fixpoint.Smt.Serialize (initSMTEnv)
import           Language.Fixpoint.Types.PrettyPrint ()
import           Language.Fixpoint.Smt.Interface
import           Language.Fixpoint.Solver.Validate
-- import           Language.Fixpoint.Solver.Solution
import           Data.Maybe           (isJust, catMaybes)
import           Text.PrettyPrint.HughesPJ (text)
import           Control.Monad.State.Strict
import qualified Data.HashMap.Strict as M

import           Control.Exception.Base (bracket)

---------------------------------------------------------------------------
-- | Solver Monadic API ---------------------------------------------------
---------------------------------------------------------------------------

-- | The solver monad: a strict 'StateT' that threads a 'SolverState'
--   on top of 'IO'.
type SolveM = StateT SolverState IO

-- | State carried by every 'SolveM' computation.
data SolverState = SS
  { ssCtx   :: !Context   -- ^ SMT solver context handed to every query
  , ssBinds :: !F.BindEnv -- ^ Bind environment: all variables and their types
  , ssStats :: !Stats     -- ^ Running solver statistics
  }

-- | Counters collected while solving.
data Stats = Stats
  { numCstr :: !Int -- ^ Number of Horn constraints
  , numIter :: !Int -- ^ Number of refine iterations
  , numBrkt :: !Int -- ^ Number of smtBracket calls (push/pop)
  , numChck :: !Int -- ^ Number of smtCheckUnsat calls
  , numVald :: !Int -- ^ Number of times the SMT solver said the RHS is valid
  }
  deriving (Show, Generic)

-- | No methods are defined here, so the class defaults are used
--   (presumably the 'Generic'-based default, since 'Stats' derives 'Generic').
instance NFData Stats

-- | Initial statistics for a query: the constraint count is the size of
--   the constraint map 'F.cm'; every other counter starts at zero.
stats0 :: F.GInfo c b -> Stats
stats0 fi = Stats
  { numCstr = M.size (F.cm fi)
  , numIter = 0
  , numBrkt = 0
  , numChck = 0
  , numVald = 0
  }

-- | Render the statistics as a table of (label, value) rows.
instance F.PTable Stats where
  ptable s = F.DocTable
    [ row "# Constraints"         numCstr
    , row "# Refine Iterations"   numIter
    , row "# SMT Brackets"        numBrkt
    , row "# SMT Queries (Valid)" numVald
    , row "# SMT Queries (Total)" numChck
    ]
    where
      -- One table row: the label next to the pretty-printed counter.
      row label field = (text label, pprint (field s))

---------------------------------------------------------------------------
-- | Run a 'SolveM' action against an SMT context obtained from
--   'makeContextWithSEnv'.
--
--   The context is acquired and released with 'bracket', so
--   'cleanupContext' runs whether or not the action throws. Before the
--   supplied action runs, the initial SMT environment and the symbols and
--   literals of the query are declared. After it finishes, @(exit)@ is
--   written to the solver. The 'Int' argument is ignored.
runSolverM :: Config -> F.GInfo c b -> Int -> SolveM a -> IO a
---------------------------------------------------------------------------
runSolverM cfg fi' _ act =
  bracket mkCtx cleanupContext $ \ctx -> do
    out <- evalStateT session (SS ctx (F.bs fi) (stats0 fi))
    smtWrite ctx "(exit)"
    return out
  where
    -- declarations first, then the caller's action
    session   = declareInitEnv >> declare fi >> act
    mkCtx     = makeContextWithSEnv linArith (solver cfg) (F.fileName fi) sorts
    -- sorts of all literals followed by the sorts of all binders
    sorts     = F.fromListSEnv (F.toListSEnv (F.lits fi) ++ bindSorts)
    bindSorts = [(x, F.sr_sort t) | (_, x, t) <- F.bindEnvToList (F.bs fi)]
    -- only linear arithmetic when: linear flag is on or solver /= Z3
    linArith  = linear cfg || Z3 /= solver cfg
    -- the higher-order flag of the query is taken from the config
    fi        = fi' {F.allowHO = C.allowHO cfg}

---------------------------------------------------------------------------
-- | The bind environment stored in the solver state.
getBinds :: SolveM F.BindEnv
---------------------------------------------------------------------------
getBinds = gets ssBinds

---------------------------------------------------------------------------
-- | The current value of the refine-iteration counter.
getIter :: SolveM Int
---------------------------------------------------------------------------
getIter = gets (numIter . ssStats)

---------------------------------------------------------------------------
-- | Add one to the refine-iteration counter ('incIter') or to the
--   bracket counter ('incBrkt').
incIter, incBrkt :: SolveM ()
---------------------------------------------------------------------------
incIter = modifyStats bump
  where
    bump st = st { numIter = numIter st + 1 }
incBrkt = modifyStats bump
  where
    bump st = st { numBrkt = numBrkt st + 1 }

---------------------------------------------------------------------------
-- | Add @n@ to the total-query counter ('incChck') or to the
--   valid-query counter ('incVald').
incChck, incVald :: Int -> SolveM ()
---------------------------------------------------------------------------
incChck n = modifyStats addTo
  where
    addTo st = st { numChck = numChck st + n }
incVald n = modifyStats addTo
  where
    addTo st = st { numVald = numVald st + n }

-- | Run an 'IO' action that needs the SMT context stored in the state.
withContext :: (Context -> IO a) -> SolveM a
withContext k = do
  ctx <- getContext
  lift (k ctx)

-- | The SMT context stored in the solver state.
getContext :: SolveM Context
getContext = gets ssCtx

-- | Apply a function to the statistics held in the solver state,
--   leaving the other fields untouched.
modifyStats :: (Stats -> Stats) -> SolveM ()
modifyStats f = modify update
  where
    update st = st { ssStats = f (ssStats st) }

---------------------------------------------------------------------------
-- | SMT Interface --------------------------------------------------------
---------------------------------------------------------------------------
-- | Return the payloads of the candidates accepted by 'filterValid_'
--   for the hypothesis @p@, running the whole check inside one
--   'smtBracket'.
--
--   Afterwards the statistics are updated: one bracket, one check per
--   candidate, and one valid result per returned payload.
filterValid :: F.Expr -> F.Cand a -> SolveM [a]
---------------------------------------------------------------------------
filterValid p qs = do
  kept <- withContext $ \ctx -> smtBracket ctx (filterValid_ p qs ctx)
  -- stats
  incBrkt
  incChck (length qs)
  incVald (length kept)
  return kept

-- | Assert @p@, then test each candidate in its own 'smtBracket':
--   the negation of the candidate is asserted and the payload is kept
--   when 'smtCheckUnsat' answers 'True'.
filterValid_ :: F.Expr -> F.Cand a -> Context -> IO [a]
filterValid_ p qs me = do
  smtAssert me p
  catMaybes <$> mapM check qs
  where
    -- 'Just' the payload when the negated candidate is unsatisfiable
    check (q, x) = smtBracket me $ do
      smtAssert me (F.PNot q)
      ok <- smtCheckUnsat me
      return (if ok then Just x else Nothing)


-- | Write the option @(set-option :smt.mbqi true)@ to the SMT context
--   held in the solver state.
smtEnablrmbqi :: SolveM ()
smtEnablrmbqi
  = withContext $ \me ->
      smtWrite me "(set-option :smt.mbqi true)"


-- | Run 'smtCheckSat' on @p@ inside an 'smtBracket' and return its
--   answer. The solver statistics are not updated.
checkSat :: F.Expr -> SolveM Bool
checkSat p = withContext query
  where
    query ctx = smtBracket ctx (smtCheckSat ctx p)

---------------------------------------------------------------------------
-- | Declare the symbols of the query with 'smtDecl', then pass each
--   group of literals from 'declLiterals' to 'smtDistinct'.
--   An error returned by 'declSymbols' is handed to 'E.die'.
declare :: F.GInfo c a -> SolveM ()
---------------------------------------------------------------------------
declare fi = withContext $ \ctx -> do
  symSorts <- either E.die return (declSymbols fi)
  mapM_ (uncurry (smtDecl ctx)) symSorts
  mapM_ (smtDistinct ctx) (declLiterals fi)

-- | Declare every entry of 'initSMTEnv' with 'smtDecl'.
declareInitEnv :: SolveM ()
declareInitEnv = withContext $ \ctx ->
  mapM_ (uncurry (smtDecl ctx)) (F.toListSEnv initSMTEnv)

-- | Groups of literal expressions, one list per key returned by
--   'groupList' (the keys are the literals' sorts).
--
--   Which literals take part depends on 'F.allowHO': when it is set,
--   literals that 'Thy.smt2Symbol' recognises are dropped; otherwise
--   literals whose sort is function-sorted are dropped.
declLiterals :: F.GInfo c a -> [[F.Expr]]
declLiterals fi = map snd (groupList candidates)
  where
    candidates = [ (t, F.expr x)
                 | (x, t) <- F.toListSEnv (F.lits fi)
                 , keep x t
                 ]
    keep x t
      | F.allowHO fi = not (isJust (Thy.smt2Symbol x))
      | otherwise    = not (F.isFunctionSortedReft (F.RR t F.trueReft))

-- | The symbol/sort pairs from 'symbolSorts', without the symbols that
--   'Thy.smt2Symbol' recognises. An error from 'symbolSorts' is passed
--   through unchanged.
declSymbols :: F.GInfo c a -> Either E.Error [(F.Symbol, F.Sort)]
declSymbols fi = filter notTheory <$> symbolSorts fi
  where
    notTheory (x, _) = not (isJust (Thy.smt2Symbol x))

---------------------------------------------------------------------------
-- | The statistics currently held in the solver state.
stats :: SolveM Stats
---------------------------------------------------------------------------
stats = gets ssStats

---------------------------------------------------------------------------
-- | Advance the progress display when @newScc@ is set, add one to the
--   refine-iteration counter, and return the updated count.
tickIter :: Bool -> SolveM Int
---------------------------------------------------------------------------
tickIter newScc = do
  progIter newScc
  incIter
  getIter

-- | Run 'progressTick' when @newScc@ is 'True'; otherwise do nothing.
progIter :: Bool -> SolveM ()
progIter newScc = when newScc (lift progressTick)