module Language.SMTLib2.Internals.Monad where
import Language.SMTLib2.Internals.Backend as B
import Language.SMTLib2.Internals.Type
import Control.Monad.State.Strict
import Control.Exception (onException)
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
#endif
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Foldable (foldlM)
newtype SMT b a = SMT { runSMT :: StateT (SMTState b) (SMTMonad b) a }
data SMTState b = SMTState { backend :: !b
, datatypes :: !(Set String) }
instance Backend b => Functor (SMT b) where
fmap f (SMT act) = SMT (fmap f act)
instance Backend b => Applicative (SMT b) where
pure x = SMT (pure x)
(<*>) (SMT fun) (SMT arg) = SMT (fun <*> arg)
instance Backend b => Monad (SMT b) where
(>>=) (SMT act) app = SMT (act >>= (\res -> case app res of
SMT p -> p))
instance Backend b => MonadState (SMTState b) (SMT b) where
get = SMT get
put x = SMT (put x)
state act = SMT (state act)
instance (Backend b,MonadIO (SMTMonad b)) => MonadIO (SMT b) where
liftIO act = SMT (liftIO act)
withBackend :: Backend b => SMTMonad b b
-> SMT b a
-> SMTMonad b a
withBackend constr act = do
b <- constr
(res,nb) <- runStateT (runSMT act) (SMTState b Set.empty)
exit (backend nb)
return res
withBackendExitCleanly :: (Backend b,SMTMonad b ~ IO) => IO b -> SMT b a -> IO a
withBackendExitCleanly constr (SMT act) = do
b <- constr
(do
(res,nb) <- runStateT act (SMTState b Set.empty)
exit (backend nb)
return res) `onException` (exit b)
liftSMT :: Backend b => SMTMonad b a -> SMT b a
liftSMT act = SMT (lift act)
embedSMT :: Backend b => (b -> SMTMonad b (a,b)) -> SMT b a
embedSMT act = SMT $ do
b <- get
(res,nb) <- lift $ act (backend b)
put (b { backend = nb })
return res
embedSMT' :: Backend b => (b -> SMTMonad b b) -> SMT b ()
embedSMT' act = SMT $ do
b <- get
nb <- lift $ act (backend b)
put (b { backend = nb })
registerDatatype :: (Backend b,IsDatatype dt) => Datatype dt -> SMT b ()
registerDatatype pr = do
st <- get
if Set.member (datatypeName pr) (datatypes st)
then return ()
else do
let (ndts,deps) = dependencies (datatypes st) pr
nb <- foldlM (\b dts -> do
((),nb) <- liftSMT $ B.declareDatatypes dts b
return nb
) (backend st) deps
put $ st { backend = nb
, datatypes = ndts }
defineVar' :: (B.Backend b) => B.Expr b t -> SMT b (B.Var b t)
defineVar' e = embedSMT $ B.defineVar Nothing e
defineVarNamed' :: (B.Backend b) => String -> B.Expr b t -> SMT b (B.Var b t)
defineVarNamed' name e = embedSMT $ B.defineVar (Just name) e
declareVar' :: B.Backend b => Repr t -> SMT b (B.Var b t)
declareVar' tp = embedSMT $ B.declareVar tp Nothing
declareVarNamed' :: B.Backend b => Repr t -> String -> SMT b (B.Var b t)
declareVarNamed' tp name = embedSMT $ B.declareVar tp (Just name)