-- |
-- Module      : ToySolver.SAT.TseitinEncoder
--
-- Tseitin encoding of propositional 'Formula's into constraints of a
-- 'SAT.Solver'.
--
-- Sub-formulas are replaced by fresh literals together with defining
-- clauses. Conjunctions (keyed by their literal set) and if-then-else nodes
-- (keyed by their condition/then/else literals) are memoized per 'Encoder',
-- and the @...WithPolarity@ functions emit only the directions of a
-- definition requested by their 'Polarity' argument.
module ToySolver.SAT.TseitinEncoder
  (
  -- * Encoder
    Encoder
  , newEncoder
  , setUsePB
  , encSolver

  -- * Polarity
  , Polarity (..)
  , negatePolarity
  , polarityPos
  , polarityNeg
  , polarityBoth
  , polarityNone

  -- * Formulas
  , Formula
  , evalFormula
  , addFormula

  -- * Encoding of individual connectives
  , encodeConj
  , encodeConjWithPolarity
  , encodeDisj
  , encodeDisjWithPolarity
  , encodeITE
  , encodeITEWithPolarity

  -- * Auxiliary variable definitions
  , getDefinitions
  ) where
import Control.Monad
import Data.IORef
import Data.Map (Map)
import qualified Data.Map as Map
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import ToySolver.Data.Boolean
import ToySolver.Data.BoolExpr
import qualified ToySolver.SAT as SAT
import qualified ToySolver.SAT.Types as SAT
-- | Propositional formulas whose atoms are SAT literals.
type Formula = BoolExpr SAT.Lit

-- | Evaluate a 'Formula' under a model: every atom is interpreted with
-- 'SAT.evalLit' and the Boolean structure is collapsed with 'fold'.
evalFormula :: SAT.IModel m => m -> Formula -> Bool
evalFormula model formula = fold evalAtom formula
  where
    evalAtom = SAT.evalLit model
-- | State of a Tseitin encoder bound to one 'SAT.Solver'.
data Encoder =
  Encoder
  { encSolver    :: SAT.Solver
    -- ^ solver that receives the fresh variables and defining constraints
  , encUsePB     :: !(IORef Bool)
    -- ^ when 'True', the positive half of a conjunction definition is
    -- added as one pseudo-Boolean constraint instead of binary clauses
  , encConjTable :: !(IORef (Map SAT.LitSet (SAT.Lit, Bool, Bool)))
    -- ^ memo table for conjunctions: literal set ↦
    -- (defining literal, positive half emitted?, negative half emitted?)
  , encITETable  :: !(IORef (Map (SAT.Lit, SAT.Lit, SAT.Lit) (SAT.Lit, Bool, Bool)))
    -- ^ memo table for if-then-else: (cond, then, else) literals ↦
    -- (defining literal, positive half emitted?, negative half emitted?)
  }
-- | Create an encoder for the given solver, with pseudo-Boolean encoding
-- switched off and both memo tables empty.
newEncoder :: SAT.Solver -> IO Encoder
newEncoder solver = do
  pbFlag    <- newIORef False
  conjMemo  <- newIORef Map.empty
  iteMemo   <- newIORef Map.empty
  return (Encoder solver pbFlag conjMemo iteMemo)
-- | Choose whether the positive half of conjunction definitions is emitted
-- as a single pseudo-Boolean constraint ('True') or as binary clauses
-- ('False'). The flag is read each time a conjunction is encoded, so it
-- only affects constraints added after this call.
setUsePB :: Encoder -> Bool -> IO ()
setUsePB = writeIORef . encUsePB
-- | Assert a formula in the encoder's solver.
--
-- Top-level conjunctions are split into separate assertions, equivalences
-- and negated equivalences become two clauses over the literals of their
-- operands, negation is pushed through 'Not', 'Or' and 'Imply', and any
-- other formula is asserted as the single clause built by 'encodeToClause'.
addFormula :: Encoder -> Formula -> IO ()
addFormula encoder = go
  where
    solver = encSolver encoder

    go (And fs) = mapM_ go fs
    go (Equiv a b) = do
      x <- encodeToLit encoder a
      y <- encodeToLit encoder b
      SAT.addClause solver [SAT.litNot x, y] -- a -> b
      SAT.addClause solver [SAT.litNot y, x] -- b -> a
    go (Not (Not a)) = go a
    go (Not (Or fs)) = go (andB (map notB fs))
    go (Not (Imply a b)) = go (a .&&. notB b)
    go (Not (Equiv a b)) = do
      x <- encodeToLit encoder a
      y <- encodeToLit encoder b
      SAT.addClause solver [x, y]                       -- a \/ b
      SAT.addClause solver [SAT.litNot x, SAT.litNot y] -- ~a \/ ~b
    go f = SAT.addClause solver =<< encodeToClause encoder f
-- | Build one clause for a formula.
--
-- Disjunctions are flattened into a single literal list; singleton 'And',
-- double negation, negated conjunction (via De Morgan) and implication are
-- rewritten first. Any remaining sub-formula contributes one literal,
-- encoded with positive polarity only.
encodeToClause :: Encoder -> Formula -> IO SAT.Clause
encodeToClause encoder = go
  where
    go :: Formula -> IO SAT.Clause
    go (And [x])      = go x
    go (Or xs)        = fmap concat (mapM go xs)
    go (Not (Not x))  = go x
    go (Not (And xs)) = go (orB (map notB xs))
    go (Imply a b)    = go (notB a .||. b)
    go f              = fmap (: []) (encodeToLitWithPolarity encoder polarityPos f)
-- | Encode a formula as a literal, requesting both directions of every
-- definition ('polarityBoth').
encodeToLit :: Encoder -> Formula -> IO SAT.Lit
encodeToLit encoder formula = encodeToLitWithPolarity encoder polarityBoth formula
-- | Encode a formula as a literal, emitting only the definition directions
-- requested by the polarity.
--
-- Atoms are returned as-is. Negation flips the polarity of its operand.
-- Operands of an equivalence and the condition of an if-then-else are
-- encoded with 'polarityBoth', because they occur both positively and
-- negatively in the expanded formula.
encodeToLitWithPolarity :: Encoder -> Polarity -> Formula -> IO SAT.Lit
encodeToLitWithPolarity encoder = enc
  where
    enc :: Polarity -> Formula -> IO SAT.Lit
    enc _ (Atom l)    = return l
    enc p (And xs)    = mapM (enc p) xs >>= encodeConjWithPolarity encoder p
    enc p (Or xs)     = mapM (enc p) xs >>= encodeDisjWithPolarity encoder p
    enc p (Not x)     = fmap SAT.litNot (enc (negatePolarity p) x)
    enc p (Imply x y) = enc p (notB x .||. y)
    enc p (Equiv x y) = do
      a <- enc polarityBoth x
      b <- enc polarityBoth y
      enc p ((Atom a .=>. Atom b) .&&. (Atom b .=>. Atom a))
    enc p (ITE c t e) = do
      cLit <- enc polarityBoth c
      tLit <- enc p t
      eLit <- enc p e
      encodeITEWithPolarity encoder p cLit tLit eLit
-- | Literal equivalent to the conjunction of the given literals (both
-- definition directions emitted).
encodeConj :: Encoder -> [SAT.Lit] -> IO SAT.Lit
encodeConj encoder lits = encodeConjWithPolarity encoder polarityBoth lits
-- | Return a literal @l@ defined as the conjunction of the given literals.
--
-- * With the positive flag set, @l -> (l1 /\ ... /\ ln)@ is emitted, either
--   as binary clauses or, if 'setUsePB' was enabled, as one pseudo-Boolean
--   constraint.
-- * With the negative flag set, @(l1 /\ ... /\ ln) -> l@ is emitted as one
--   clause.
--
-- Results are memoized per literal set: encoding the same set again returns
-- the same literal and only adds the directions not emitted before. A
-- single-literal list is returned unchanged without a new variable.
encodeConjWithPolarity :: Encoder -> Polarity -> [SAT.Lit] -> IO SAT.Lit
encodeConjWithPolarity _ _ [l] = return l
encodeConjWithPolarity encoder (Polarity pos neg) ls = do
  let ls2    = IntSet.fromList ls
      lits   = IntSet.toList ls2
      solver = encSolver encoder
  usePB <- readIORef (encUsePB encoder)
  table <- readIORef (encConjTable encoder)
  let
    -- l -> li for every i
    definePos :: SAT.Lit -> IO ()
    definePos l
      | usePB = do
          -- (forall i. l -> li)  <=>  sum li >= n*l  <=>  sum li - n*l >= 0
          -- The coefficient of l must be negative; with +n the constraint
          -- would be trivially satisfied and enforce nothing.
          let n = IntSet.size ls2
          SAT.addPBAtLeast solver ((negate (fromIntegral n), l) : [(1, li) | li <- lits]) 0
      | otherwise =
          -- (l -> li)  <=>  (~l \/ li)
          forM_ lits $ \li -> SAT.addClause solver [SAT.litNot l, li]

    -- (l1 /\ ... /\ ln) -> l  <=>  (l \/ ~l1 \/ ... \/ ~ln)
    defineNeg :: SAT.Lit -> IO ()
    defineNeg l = SAT.addClause solver (l : map SAT.litNot lits)

  case Map.lookup ls2 table of
    Just (l, posDefined, negDefined) -> do
      -- Reuse the existing literal; add only the missing directions.
      when (pos && not posDefined) $ definePos l
      when (neg && not negDefined) $ defineNeg l
      when (posDefined < pos || negDefined < neg) $
        modifyIORef' (encConjTable encoder)
          (Map.insert ls2 (l, posDefined || pos, negDefined || neg))
      return l
    Nothing -> do
      l <- SAT.newVar solver
      when pos $ definePos l
      when neg $ defineNeg l
      modifyIORef' (encConjTable encoder) (Map.insert ls2 (l, pos, neg))
      return l
-- | Literal equivalent to the disjunction of the given literals (both
-- definition directions emitted).
encodeDisj :: Encoder -> [SAT.Lit] -> IO SAT.Lit
encodeDisj encoder lits = encodeDisjWithPolarity encoder polarityBoth lits
-- | Literal defined as the disjunction of the given literals, via De Morgan:
-- @l1 \/ ... \/ ln == ~(~l1 /\ ... /\ ~ln)@. Because the result is the
-- negation of a conjunction literal, the conjunction is encoded with the
-- flipped polarity. A single-literal list is returned unchanged.
encodeDisjWithPolarity :: Encoder -> Polarity -> [SAT.Lit] -> IO SAT.Lit
encodeDisjWithPolarity _ _ [l] = return l
encodeDisjWithPolarity encoder p ls =
  fmap SAT.litNot $
    encodeConjWithPolarity encoder (negatePolarity p) (map SAT.litNot ls)
-- | Literal equivalent to @if c then t else e@ (both definition directions
-- emitted).
encodeITE :: Encoder -> SAT.Lit -> SAT.Lit -> SAT.Lit -> IO SAT.Lit
encodeITE encoder c t e = encodeITEWithPolarity encoder polarityBoth c t e
-- | Return a literal @x@ defined as @ite(c, t, e)@, i.e.
-- @(c /\ t) \/ (~c /\ e)@.
--
-- A negative condition is normalized first, using
-- @ite(~c, t, e) == ite(c, e, t)@, so both forms share one memo entry.
--
-- * With the positive flag set, @x -> ite(c, t, e)@ is emitted.
-- * With the negative flag set, @ite(c, t, e) -> x@ is emitted.
--
-- As for conjunctions, results are memoized and only missing directions
-- are added on reuse.
encodeITEWithPolarity :: Encoder -> Polarity -> SAT.Lit -> SAT.Lit -> SAT.Lit -> IO SAT.Lit
encodeITEWithPolarity encoder p c t e | c < 0 =
  -- Literals are signed integers here (as the c < 0 test implies), so
  -- negate c is the complementary, positive literal. This guarantees that
  -- the recursion terminates.
  encodeITEWithPolarity encoder p (negate c) e t
encodeITEWithPolarity encoder (Polarity pos neg) c t e = do
  let solver = encSolver encoder
  table <- readIORef (encITETable encoder)
  let
    -- x -> ite(c,t,e)
    --   <=> (x /\ c -> t) /\ (x /\ ~c -> e)
    --   <=> (~x \/ ~c \/ t) /\ (~x \/ c \/ e)
    definePos :: SAT.Lit -> IO ()
    definePos x = do
      SAT.addClause solver [SAT.litNot x, SAT.litNot c, t]
      SAT.addClause solver [SAT.litNot x, c, e]
      -- Redundant (x -> t \/ e is implied by the two clauses above);
      -- presumably kept to strengthen unit propagation.
      SAT.addClause solver [SAT.litNot x, t, e]

    -- ite(c,t,e) -> x
    --   <=> (c /\ t -> x) /\ (~c /\ e -> x)
    --   <=> (~c \/ ~t \/ x) /\ (c \/ ~e \/ x)
    defineNeg :: SAT.Lit -> IO ()
    defineNeg x = do
      SAT.addClause solver [SAT.litNot c, SAT.litNot t, x]
      SAT.addClause solver [c, SAT.litNot e, x]
      -- Redundant (t /\ e -> x is implied by the two clauses above);
      -- presumably kept to strengthen unit propagation.
      SAT.addClause solver [SAT.litNot t, SAT.litNot e, x]

  case Map.lookup (c, t, e) table of
    Just (l, posDefined, negDefined) -> do
      when (pos && not posDefined) $ definePos l
      when (neg && not negDefined) $ defineNeg l
      when (posDefined < pos || negDefined < neg) $
        modifyIORef' (encITETable encoder)
          (Map.insert (c, t, e) (l, posDefined || pos, negDefined || neg))
      return l
    Nothing -> do
      l <- SAT.newVar solver
      when pos $ definePos l
      when neg $ defineNeg l
      modifyIORef' (encITETable encoder) (Map.insert (c, t, e) (l, pos, neg))
      return l
-- | Definitions of the auxiliary literals introduced so far. Each pair maps
-- a fresh literal to the formula it stands for:
--
-- * a conjunction created by 'encodeConjWithPolarity'; disjunctions appear
--   here as the conjunction of the negated literals;
-- * an if-then-else created by 'encodeITEWithPolarity', with its condition
--   normalized to a positive literal.
--
-- Conjunction definitions come first, followed by if-then-else
-- definitions. Depending on the polarity used when encoding, the solver
-- may only contain one direction of a given definition.
getDefinitions :: Encoder -> IO [(SAT.Lit, Formula)]
getDefinitions encoder = do
  conjTable <- readIORef (encConjTable encoder)
  iteTable  <- readIORef (encITETable encoder)
  let conjDefs =
        [ (l, andB [Atom l1 | l1 <- IntSet.toList ls])
        | (ls, (l, _, _)) <- Map.toList conjTable ]
      -- Previously omitted: without these, callers could not account for
      -- the auxiliary variables introduced by encodeITE.
      iteDefs =
        [ (l, ITE (Atom c) (Atom t) (Atom e))
        | ((c, t, e), (l, _, _)) <- Map.toList iteTable ]
  return (conjDefs ++ iteDefs)
-- | Which directions of a Tseitin definition a sub-formula needs. For a
-- fresh literal @x@ naming a sub-formula @phi@:
--
-- * 'polarityPosOccurs' requests @x -> phi@;
-- * 'polarityNegOccurs' requests @phi -> x@.
data Polarity = Polarity
  { polarityPosOccurs :: Bool
    -- ^ the sub-formula occurs positively
  , polarityNegOccurs :: Bool
    -- ^ the sub-formula occurs negatively
  }
  deriving (Eq, Show)

-- | Swap the two flags. Used when descending below a negation.
negatePolarity :: Polarity -> Polarity
negatePolarity p =
  p { polarityPosOccurs = polarityNegOccurs p
    , polarityNegOccurs = polarityPosOccurs p
    }

-- | Positive occurrences only.
polarityPos :: Polarity
polarityPos = Polarity { polarityPosOccurs = True, polarityNegOccurs = False }

-- | Negative occurrences only.
polarityNeg :: Polarity
polarityNeg = Polarity { polarityPosOccurs = False, polarityNegOccurs = True }

-- | Both positive and negative occurrences (full equivalence).
polarityBoth :: Polarity
polarityBoth = Polarity { polarityPosOccurs = True, polarityNegOccurs = True }

-- | No occurrences: no definition direction is required.
polarityNone :: Polarity
polarityNone = Polarity { polarityPosOccurs = False, polarityNegOccurs = False }