{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE CPP #-}
module ToySolver.EUF.EUFSolver
(
Solver
, newSolver
, FSym
, Term (..)
, ConstrID
, VAFun (..)
, newFSym
, newFun
, newConst
, assertEqual
, assertEqual'
, assertNotEqual
, assertNotEqual'
, check
, areEqual
, explain
, Entity
, EntityTuple
, Model (..)
, getModel
, eval
, evalAp
, pushBacktrackPoint
, popBacktrackPoint
, termToFlatTerm
, termToFSym
, fsymToTerm
, fsymToFlatTerm
, flatTermToFSym
) where
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Except
import Data.Either
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.IORef
import qualified ToySolver.Internal.Data.Vec as Vec
import ToySolver.EUF.CongruenceClosure (FSym, Term (..), ConstrID, VAFun (..))
import ToySolver.EUF.CongruenceClosure (Model (..), Entity, EntityTuple, eval, evalAp)
import qualified ToySolver.EUF.CongruenceClosure as CC
-- | State of the EUF solver.
--
-- Equalities are delegated to a congruence-closure solver; this record adds
-- the bookkeeping for disequalities, conflict explanations and backtracking.
data Solver
  = Solver
  { svCCSolver :: !CC.Solver
    -- ^ Underlying congruence-closure solver.
  , svDisequalities :: IORef (Map (Term, Term) (Maybe ConstrID))
    -- ^ Asserted disequalities. Each key is a term pair stored with the
    -- smaller term first; the value is the optional constraint ID of the
    -- assertion.
  , svExplanation :: IORef IntSet
    -- ^ Explanation recorded by the most recent 'check' that found a
    -- conflict.
  , svBacktrackPoints :: !(Vec.Vec (Map (Term, Term) ()))
    -- ^ One entry per pushed backtrack point: the disequalities added while
    -- that point was the innermost one. They are removed again on pop.
  }
-- | Create a fresh solver with no asserted equalities or disequalities and
-- no pushed backtrack points.
newSolver :: IO Solver
newSolver = do
  cc <- CC.newSolver
  deqs <- newIORef Map.empty
  -- Start from an empty explanation rather than 'undefined'. Otherwise
  -- @'explain' solver Nothing@ called before 'check' has ever reported a
  -- conflict would hand the caller 'undefined' and crash when forced.
  expl <- newIORef IntSet.empty
  bp <- Vec.new
  return
    Solver
      { svCCSolver = cc
      , svDisequalities = deqs
      , svExplanation = expl
      , svBacktrackPoints = bp
      }
-- | Create a new function symbol by delegating to 'CC.newFSym' on the
-- underlying congruence-closure solver.
newFSym :: Solver -> IO FSym
newFSym = CC.newFSym . svCCSolver
-- | Create a new constant term by delegating to 'CC.newConst' on the
-- underlying congruence-closure solver.
newConst :: Solver -> IO Term
newConst = CC.newConst . svCCSolver
-- | Create a new function by delegating to 'CC.newFun' on the underlying
-- congruence-closure solver. The shape of the returned value @a@ is chosen
-- by its 'CC.VAFun' instance.
newFun :: CC.VAFun a => Solver -> IO a
newFun = CC.newFun . svCCSolver
-- | Assert that two terms are equal, without attaching a constraint ID.
assertEqual :: Solver -> Term -> Term -> IO ()
assertEqual solver lhs rhs = assertEqual' solver lhs rhs Nothing
-- | Assert that two terms are equal, optionally tagging the assertion with
-- a constraint ID. This is handed directly to 'CC.merge'' on the underlying
-- congruence-closure solver.
assertEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertEqual' solver = CC.merge' (svCCSolver solver)
-- | Assert that two terms are not equal, without attaching a constraint ID.
assertNotEqual :: Solver -> Term -> Term -> IO ()
assertNotEqual solver lhs rhs = assertNotEqual' solver lhs rhs Nothing
-- | Assert that two terms are not equal, optionally tagging the assertion
-- with a constraint ID.
--
-- The pair is stored with the smaller term first, so @t1 /= t2@ and
-- @t2 /= t1@ share one entry. Re-asserting an already known disequality is
-- a no-op; the first constraint ID recorded for it is kept. When at least
-- one backtrack point is pushed, the pair is also recorded in the innermost
-- backtrack point so that 'popBacktrackPoint' can retract it.
assertNotEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertNotEqual' solver t1 t2 cid = do
  let key
        | t1 < t2 = (t1, t2)
        | otherwise = (t2, t1)
      deqRef = svDisequalities solver
  known <- readIORef deqRef
  unless (key `Map.member` known) $ do
    -- Both sides are converted to function symbols and the results are
    -- discarded, presumably so the CC solver knows these terms (e.g. for
    -- model construction) -- NOTE(review): confirm intent.
    _ <- termToFSym solver (fst key)
    _ <- termToFSym solver (snd key)
    writeIORef deqRef $! Map.insert key cid known
    level <- getCurrentLevel solver
    unless (level == 0) $
      Vec.unsafeModify' (svBacktrackPoints solver) (level - 1) (Map.insert key ())
-- | Check the asserted disequalities against the current equalities.
--
-- Returns 'False' at the first asserted disequality @t1 /= t2@ whose sides
-- are congruent in the underlying CC solver. In that case the conflict's
-- explanation is stored, retrievable via @'explain' solver Nothing@. It is
-- the set returned by 'CC.explain' for @t1@ and @t2@, plus the
-- disequality's own constraint ID if it has one. Returns 'True' if no such
-- conflict exists; the stored explanation is then left untouched.
check :: Solver -> IO Bool
check solver = do
  ds <- readIORef (svDisequalities solver)
  liftM isRight $ runExceptT $ forM_ (Map.toList ds) $ \((t1, t2), cid) -> do
    b <- lift $ CC.areCongruent cc t1 t2
    when b $ do
      ret <- lift $ CC.explain cc t1 t2
      -- Congruent terms must have an explanation. Report a violation
      -- explicitly instead of relying on a failing do-pattern (which would
      -- surface as an opaque IO "pattern match failure").
      cs <- case ret of
        Nothing -> error "ToySolver.EUF.EUFSolver.check: should not happen"
        Just s -> return s
      lift $ writeIORef (svExplanation solver) $!
        maybe cs (`IntSet.insert` cs) cid
      -- Abort the traversal: one conflict is enough.
      throwE ()
  where
    cc = svCCSolver solver
-- | Whether two terms are currently congruent in the underlying
-- congruence-closure solver.
areEqual :: Solver -> Term -> Term -> IO Bool
areEqual = CC.areCongruent . svCCSolver
-- | Retrieve an explanation.
--
-- * @Nothing@: the explanation stored by the most recent 'check' that
--   returned 'False'.
-- * @Just (t1, t2)@: the explanation 'CC.explain' gives for @t1@ and @t2@.
--   It is an error if 'CC.explain' returns 'Nothing' (presumably when the
--   terms are not currently equal; callers should check 'areEqual' first).
explain :: Solver -> Maybe (Term, Term) -> IO IntSet
explain solver query =
  case query of
    Nothing -> readIORef (svExplanation solver)
    Just (t1, t2) ->
      CC.explain (svCCSolver solver) t1 t2
        >>= maybe (error "ToySolver.EUF.EUFSolver.explain: should not happen") return
-- | Build a model by delegating to 'CC.getModel' on the underlying
-- congruence-closure solver.
getModel :: Solver -> IO Model
getModel solver = CC.getModel (svCCSolver solver)
-- | Backtracking level: the number of backtrack points currently pushed.
type Level = Int

-- | Current level, i.e. the size of the solver's backtrack-point stack.
getCurrentLevel :: Solver -> IO Level
getCurrentLevel solver = Vec.getSize (svBacktrackPoints solver)
-- | Push a backtrack point.
--
-- The underlying CC solver pushes its own point. A fresh, empty record of
-- disequalities is opened for this level, filled by 'assertNotEqual''.
pushBacktrackPoint :: Solver -> IO ()
pushBacktrackPoint solver = do
  CC.pushBacktrackPoint cc
  Vec.push points Map.empty
  where
    cc = svCCSolver solver
    points = svBacktrackPoints solver
-- | Pop the innermost backtrack point.
--
-- The underlying CC solver is popped, and every disequality recorded for
-- the popped level is removed from the asserted set. Calling this at the
-- root level (no backtrack points pushed) is an error.
popBacktrackPoint :: Solver -> IO ()
popBacktrackPoint solver = do
  level <- getCurrentLevel solver
  when (level == 0) $
    error "ToySolver.EUF.EUFSolver.popBacktrackPoint: root level"
  CC.popBacktrackPoint (svCCSolver solver)
  retracted <- Vec.unsafePop (svBacktrackPoints solver)
  modifyIORef' (svDisequalities solver) (`Map.difference` retracted)
-- | Low-level: convert a term to a flat term via 'CC.termToFlatTerm' on the
-- underlying congruence-closure solver.
termToFlatTerm :: Solver -> Term -> IO CC.FlatTerm
termToFlatTerm solver = CC.termToFlatTerm (svCCSolver solver)
-- | Low-level: map a term to a function symbol via 'CC.termToFSym' on the
-- underlying congruence-closure solver.
termToFSym :: Solver -> Term -> IO FSym
termToFSym solver = CC.termToFSym (svCCSolver solver)
-- | Low-level: map a function symbol back to a term via 'CC.fsymToTerm' on
-- the underlying congruence-closure solver.
fsymToTerm :: Solver -> FSym -> IO Term
fsymToTerm solver = CC.fsymToTerm (svCCSolver solver)
-- | Low-level: map a function symbol to a flat term via 'CC.fsymToFlatTerm'
-- on the underlying congruence-closure solver.
fsymToFlatTerm :: Solver -> FSym -> IO CC.FlatTerm
fsymToFlatTerm solver = CC.fsymToFlatTerm (svCCSolver solver)
-- | Low-level: map a flat term to a function symbol via
-- 'CC.flatTermToFSym' on the underlying congruence-closure solver.
flatTermToFSym :: Solver -> CC.FlatTerm -> IO FSym
flatTermToFSym solver = CC.flatTermToFSym (svCCSolver solver)