module ToySolver.SAT.PBO.BCD2
( Options (..)
, defaultOptions
, solve
) where
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import qualified Data.IntSet as IntSet
import qualified Data.IntMap as IntMap
import qualified ToySolver.SAT as SAT
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.PBO.Context as C
import qualified ToySolver.Knapsack as Knapsack
import Text.Printf
-- | Tuning switches for the BCD2 search.
data Options = Options
  { -- | When 'True', after each iteration every literal whose deducted
    -- weight @w@ satisfies @lb + w > ub@ is fixed with a unit clause.
    optEnableHardening :: Bool
  , -- | When 'True', the trial bound of each core is interpolated between
    -- its lower bound and its estimated cost using the ratio
    -- @nunsat / (nunsat + nsat)@; otherwise the midpoint is used.
    optEnableBiasedSearch :: Bool
  , -- | When 'True' and the context holds no best model yet, a plain SAT
    -- call is made first to obtain an initial upper bound.
    optSolvingNormalFirst :: Bool
  }
-- | Default configuration: hardening, biased search and the initial plain
-- SAT call are all switched on.
defaultOptions :: Options
defaultOptions =
  Options
    { optEnableHardening    = True
    , optEnableBiasedSearch = True
    , optSolvingNormalFirst = True
    }
-- | Bookkeeping for one core tracked by the search.
data CoreInfo = CoreInfo
  { -- | Literals that belong to this core.
    coreLits :: SAT.LitSet
  , -- | Lower bound currently known for the cost of this core (strict).
    -- The overall lower bound is the sum of this field over all cores.
    coreLB :: !Integer
  }
-- | Entry point: run the BCD2 search on the given context.
--
-- The context is first passed through 'C.normalize' and the result is
-- handed to 'solveWBO' together with the solver and the options.
solve :: C.Context cxt => cxt -> SAT.Solver -> Options -> IO ()
solve cxt = solveWBO (C.normalize cxt)
-- | Core-guided search that minimises the context's objective function.
--
-- Every objective term @(w, lit)@ yields a selector literal @-lit@ of weight
-- @w@.  Selectors start out /unrelaxed/ (passed as assumptions); when they
-- show up in a failed-assumption core they become /relaxed/ and are tracked
-- by a 'CoreInfo' carrying a lower bound on the cost of that core.  The
-- search alternates between tightening the upper bound (SAT answers) and
-- raising core lower bounds (UNSAT answers) until the bounds cross or the
-- solver reports an empty core.
--
-- NOTE(review): the bound reasoning below presumes all objective
-- coefficients are positive (presumably guaranteed by 'C.normalize' in
-- 'solve') -- confirm against "ToySolver.SAT.PBO.Context".
solveWBO :: C.Context cxt => cxt -> SAT.Solver -> Options -> IO ()
solveWBO cxt solver opt = do
  SAT.setEnableBackwardSubsumptionRemoval solver True
  let unrelaxed = IntSet.fromList [lit | (lit, _) <- sels]
      relaxed   = IntSet.empty
      hardened  = IntSet.empty
      cnt       = (1, 1)
  best <- atomically $ C.getBestModel cxt
  case best of
    Just m ->
      -- A model is already known: only strictly better values are of interest.
      loop (unrelaxed, relaxed, hardened) weights [] (SAT.evalPBLinSum m obj - 1) (Just m) cnt
    Nothing
      | optSolvingNormalFirst opt -> do
          ret <- SAT.solve solver
          if ret
            then do
              m <- SAT.model solver
              let ub' = SAT.evalPBLinSum m obj - 1
              C.logMessage cxt $ printf "BCD2: updating upper bound: %d -> %d" (SAT.pbUpperBound obj) ub'
              C.addSolution cxt m
              SAT.addPBAtMost solver obj ub'
              loop (unrelaxed, relaxed, hardened) weights [] ub' (Just m) cnt
            else
              -- The hard constraints alone are unsatisfiable.
              C.setFinished cxt
      | otherwise ->
          loop (unrelaxed, relaxed, hardened) weights [] (SAT.pbUpperBound obj) Nothing cnt
  where
    obj :: SAT.PBLinSum
    obj = C.getObjectiveFunction cxt

    -- Selector literals with their weights.  A selector is the negation of
    -- an objective literal: assuming it true means "this term costs nothing".
    -- Only with this polarity does a failed-assumption core imply that at
    -- least one objective literal of the core is true, which is what the
    -- 'SAT.addPBAtLeast' constraints below rely on.
    sels :: [(SAT.Lit, Integer)]
    sels = [(-lit, w) | (w, lit) <- obj]

    weights :: SAT.LitMap Integer
    weights = IntMap.fromList sels

    -- Cost of a core, expressed over the objective literals (the negations
    -- of the selectors stored in the core).
    coreCostFun :: CoreInfo -> SAT.PBLinSum
    coreCostFun c = [(weights IntMap.! lit, -lit) | lit <- IntSet.toList (coreLits c)]

    -- Overall lower bound: sum of the per-core lower bounds.
    computeLB :: [CoreInfo] -> Integer
    computeLB cores = sum [coreLB info | info <- cores]

    -- Subtract a lower-bound increase @d@ from the deducted weight of every
    -- literal in @lits@.  The hardening test in 'cont' (@lb + w > ub@) is
    -- only valid when the weights shrink by as much as the bound grew.
    deduct :: Integer -> SAT.LitSet -> SAT.LitMap Integer -> SAT.LitMap Integer
    deduct d lits ws
      | d == 0    = ws
      | otherwise = IntMap.unionWith (+) ws (IntMap.fromList [(lit, negate d) | lit <- IntSet.toList lits])

    -- One iteration: guess a bound for every core, solve under assumptions,
    -- then update the upper bound (SAT) or the cores (UNSAT).
    -- The last argument counts (SAT answers, UNSAT answers) seen so far.
    loop :: (SAT.LitSet, SAT.LitSet, SAT.LitSet) -> SAT.LitMap Integer -> [CoreInfo] -> Integer -> Maybe SAT.Model -> (Integer, Integer) -> IO ()
    loop (unrelaxed, relaxed, hardened) deductedWeight cores ub lastModel (nsat, nunsat) =
      -- Force both counters on every iteration so no chain of (+1) thunks
      -- builds up; 'seq' is used so that no language extension is needed.
      nsat `seq` nunsat `seq` do
        let lb = computeLB cores
        C.logMessage cxt $ printf "BCD2: %d <= obj <= %d" lb ub
        C.logMessage cxt $ printf "BCD2: #cores=%d, #unrelaxed=%d, #relaxed=%d, #hardened=%d"
          (length cores) (IntSet.size unrelaxed) (IntSet.size relaxed) (IntSet.size hardened)
        when (optEnableBiasedSearch opt) $
          C.logMessage cxt $ printf "BCD2: bias = %d/%d" nunsat (nunsat + nsat)

        -- For each core allocate a fresh selector variable that, when
        -- assumed, limits the core's cost to the trial bound @mid@.
        trials <- fmap IntMap.fromList $ forM cores $ \info -> do
          sel <- SAT.newVar solver
          let ep = case lastModel of
                     -- No model yet: use the total weight of the core.
                     Nothing -> sum [weights IntMap.! lit | lit <- IntSet.toList (coreLits info)]
                     Just m  -> SAT.evalPBLinSum m (coreCostFun info)
              mid
                | optEnableBiasedSearch opt = coreLB info + (ep - coreLB info) * nunsat `div` (nunsat + nsat)
                | otherwise = (coreLB info + ep) `div` 2
          SAT.addPBAtMostSoft solver sel (coreCostFun info) mid
          return (sel, (info, mid))

        ret <- SAT.solveWith solver (IntMap.keys trials ++ IntSet.toList unrelaxed)
        if ret
          then do
            m <- SAT.model solver
            let ub' = SAT.evalPBLinSum m obj - 1
            C.logMessage cxt $ printf "BCD2: updating upper bound: %d -> %d" ub ub'
            C.addSolution cxt m
            SAT.addPBAtMost solver obj ub'
            cont (unrelaxed, relaxed, hardened) deductedWeight cores ub' (Just m) (nsat + 1, nunsat)
          else do
            core <- SAT.failedAssumptions solver
            case core of
              -- Unsatisfiable without any assumption: nothing better exists.
              [] -> C.setFinished cxt

              -- A single trial bound was too tight: raise that core's bound.
              [sel] | Just (info, mid) <- IntMap.lookup sel trials -> do
                let newLB  = refine [weights IntMap.! lit | lit <- IntSet.toList (coreLits info)] mid
                    info'  = info { coreLB = newLB }
                    cores' = IntMap.elems $ IntMap.insert sel info' $ IntMap.map fst trials
                    lb'    = computeLB cores'
                    deductedWeight' = deduct (lb' - lb) (coreLits info) deductedWeight
                C.logMessage cxt $ printf "BCD2: updating lower bound of a core"
                SAT.addPBAtLeast solver (coreCostFun info') (coreLB info')
                cont (unrelaxed, relaxed, hardened) deductedWeight' cores' ub lastModel (nsat, nunsat + 1)

              -- General case: relax the unrelaxed selectors in the core and
              -- merge them with every existing core whose trial bound failed.
              _ -> do
                let coreSet     = IntSet.fromList core
                    torelax     = unrelaxed `IntSet.intersection` coreSet
                    unrelaxed'  = unrelaxed `IntSet.difference` torelax
                    relaxed'    = relaxed `IntSet.union` torelax
                    intersected = [im | (sel, im) <- IntMap.toList trials, sel `IntSet.member` coreSet]
                    rest        = [info | (sel, (info, _)) <- IntMap.toList trials, sel `IntSet.notMember` coreSet]
                    -- The list is non-empty as long as every member of the
                    -- core is one of the assumptions passed to 'SAT.solveWith'.
                    delta       = minimum $
                                    [mid - coreLB info + 1 | (info, mid) <- intersected]
                                      ++ [weights IntMap.! lit | lit <- IntSet.toList torelax]
                    newLits     = IntSet.unions $ torelax : [coreLits info | (info, _) <- intersected]
                    mergedCore  = CoreInfo
                      { coreLits = newLits
                      , coreLB   = refine [weights IntMap.! lit | lit <- IntSet.toList relaxed']
                                          (sum [coreLB info | (info, _) <- intersected] + delta - 1)
                      }
                    cores'      = mergedCore : rest
                    lb'         = computeLB cores'
                    deductedWeight' = deduct (lb' - lb) newLits deductedWeight
                C.logMessage cxt $
                  if null intersected
                    then printf "BCD2: found a new core of size %d" (IntSet.size torelax)
                    else printf "BCD2: merging cores"
                SAT.addPBAtLeast solver (coreCostFun mergedCore) (coreLB mergedCore)
                -- Permanently switch off this iteration's trial constraints.
                -- Asserting the selectors positively would instead enforce
                -- bounds that have just been shown to be infeasible.
                forM_ (IntMap.keys trials) $ \sel -> SAT.addClause solver [-sel]
                cont (unrelaxed', relaxed', hardened) deductedWeight' cores' ub lastModel (nsat, nunsat + 1)

    -- Between iterations: stop when the bounds have crossed, otherwise
    -- optionally harden literals and continue with 'loop'.
    cont :: (SAT.LitSet, SAT.LitSet, SAT.LitSet) -> SAT.LitMap Integer -> [CoreInfo] -> Integer -> Maybe SAT.Model -> (Integer, Integer) -> IO ()
    cont (unrelaxed, relaxed, hardened) deductedWeight cores ub lastModel counts
      | lb > ub = C.setFinished cxt
      | optEnableHardening opt = do
          -- Selectors whose deducted weight alone would push the cost past
          -- the upper bound are fixed to true and dropped from all sets.
          let lits = IntMap.keysSet $ IntMap.filter (\w -> lb + w > ub) deductedWeight
          forM_ (IntSet.toList lits) $ \lit -> SAT.addClause solver [lit]
          let unrelaxed' = unrelaxed `IntSet.difference` lits
              relaxed'   = relaxed `IntSet.difference` lits
              hardened'  = hardened `IntSet.union` lits
              cores'     = map (\core -> core { coreLits = coreLits core `IntSet.difference` lits }) cores
          loop (unrelaxed', relaxed', hardened') deductedWeight cores' ub lastModel counts
      | otherwise =
          loop (unrelaxed, relaxed, hardened) deductedWeight cores ub lastModel counts
      where
        lb = computeLB cores
-- | Produce a refined bound that is strictly greater than @n@.
--
-- The list of weights is accepted but currently ignored; the result is
-- always @n + 1@.
refine
  :: [Integer] -- ^ weights of the literals involved (unused)
  -> Integer   -- ^ the bound @n@ to improve on
  -> Integer
refine _ n = succ n