{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.Encoder.Cardinality.Internal.ParallelCounter
-- Copyright   :  (c) Masahiro Sakai 2019
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.Cardinality.Internal.ParallelCounter
  ( addAtLeastParallelCounter
  , encodeAtLeastParallelCounter
  ) where

import Control.Monad.Primitive
import Control.Monad.State.Strict
import Data.Bits
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin

-- | Add the cardinality constraint @sum lhs >= rhs@ (an 'SAT.AtLeast' pair)
-- to the encoder as a hard constraint.
--
-- The constraint is first turned into a single literal by
-- 'encodeAtLeastParallelCounter'; that literal is then asserted by adding
-- the unit clause @[lit]@.
addAtLeastParallelCounter :: PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m ()
addAtLeastParallelCounter enc constr =
  encodeAtLeastParallelCounter enc constr >>= \lit -> SAT.addClause enc [lit]

-- TODO: consider polarity
-- | Encode the cardinality constraint @sum lhs >= rhs@ (given as an
-- 'SAT.AtLeast' pair @(lhs, rhs)@) with a parallel counter and return a
-- literal standing for it.
--
-- * If @rhs <= 0@ the constraint is trivially satisfied, and the literal
--   from @Tseitin.encodeConj enc []@ (true) is returned.
-- * If fewer than @rhs@ literals are given, it is trivially unsatisfiable,
--   and the literal from @Tseitin.encodeDisj enc []@ (false) is returned.
-- * Otherwise the literals are summed into a counter whose width is the
--   number of bits of @rhs@. The result is the disjunction of "counter >= rhs"
--   and every carry that overflowed past that width.
encodeAtLeastParallelCounter :: forall m. PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m SAT.Lit
encodeAtLeastParallelCounter enc (lhs, rhs)
  | rhs <= 0 = Tseitin.encodeConj enc []
  | length lhs < rhs = Tseitin.encodeDisj enc []
  | otherwise = do
      -- Binary representation of rhs, least significant bit first.
      let rhsBits :: [Bool]
          rhsBits = bits (fromIntegral rhs)
      (cnt, overflowBits) <- encodeSumParallelCounter enc (length rhsBits) lhs
      isGE <- encodeGE enc cnt rhsBits
      -- rhs fits in (length rhsBits) bits. A true overflow carry therefore
      -- means the sum is at least 2^(length rhsBits), which exceeds rhs.
      Tseitin.encodeDisj enc (isGE : overflowBits)
  where
    -- | LSB-first bit list of a non-negative integer; empty for 0.
    bits :: Integer -> [Bool]
    bits n0 = go n0 0
      where
        -- The inner names differ from the outer argument, so the file's
        -- -Wall (which includes -Wname-shadowing) stays quiet.
        go :: Integer -> Int -> [Bool]
        go 0 !_ = []
        go k i = testBit k i : go (clearBit k i) (i + 1)

-- | Sum the literals @lits@ with a parallel counter built from Tseitin full
-- adders.
--
-- The input is split recursively into two halves of equal length. Both
-- halves are counted, and their counters are added bitwise. The carry-in is
-- the leftover last element of an odd-length input, or a fresh false literal
-- (@Tseitin.encodeDisj enc []@) when the length is even.
--
-- Returns @(counter, overflow)@:
--
-- * @counter@ holds the sum bits, least significant bit first, and is kept
--   to at most @w@ bits.
-- * @overflow@ collects every carry that would have gone into bit position
--   @w@.
encodeSumParallelCounter :: forall m. PrimMonad m => Tseitin.Encoder m -> Int -> [SAT.Lit] -> m ([SAT.Lit], [SAT.Lit])
encodeSumParallelCounter enc w lits = runStateT (tally (V.fromList lits)) []
  where
    -- Count the literals of a vector. Overflowing carries are pushed into
    -- the state.
    tally :: Vector SAT.Lit -> StateT [SAT.Lit] m [SAT.Lit]
    tally vec
      | V.null vec = return []
      | otherwise = do
          let n = V.length vec
              half = n `div` 2
          lo <- tally (V.slice 0 half vec)
          hi <- tally (V.slice half half vec)
          carryIn <-
            if even n
              then lift (Tseitin.encodeDisj enc [])  -- false
              else return (V.last vec)
          rippleAdd lo hi carryIn

    -- Ripple-carry addition of two LSB-first bit lists with a carry-in.
    rippleAdd :: [SAT.Lit] -> [SAT.Lit] -> SAT.Lit -> StateT [SAT.Lit] m [SAT.Lit]
    rippleAdd = rippleStep 0 []

    -- pos: current bit position; acc: sum bits produced so far (reversed).
    rippleStep :: Int -> [SAT.Lit] -> [SAT.Lit] -> [SAT.Lit] -> SAT.Lit -> StateT [SAT.Lit] m [SAT.Lit]
    rippleStep pos acc _xs _ys carry
      | pos == w = do
          -- Width limit reached: record the carry as an overflow bit.
          modify (carry :)
          return (reverse acc)
    rippleStep _ acc [] [] carry = return (reverse (carry : acc))
    rippleStep pos acc (x : xs) (y : ys) carry = do
      s <- lift (Tseitin.encodeFASum enc x y carry)
      carry' <- lift (Tseitin.encodeFACarry enc x y carry)
      rippleStep (pos + 1) (s : acc) xs ys carry'
    rippleStep _ _ _ _ _ = error "encodeSumParallelCounter: should not happen"

-- | Build a literal for @lhs >= rhs@. Here @lhs@ is a list of literal bits
-- and @rhs@ is a list of constant bits; both are ordered from the least
-- significant bit upward.
--
-- The scan keeps a literal @acc@ meaning "the lower-order part of @lhs@ is
-- @>=@ the lower-order part of @rhs@" (initially true). At each position:
--
-- * A 1-bit of @rhs@ requires the @lhs@ bit and @acc@ (conjunction).
-- * A 0-bit is satisfied by the @lhs@ bit or @acc@ (disjunction).
--
-- Positions missing from the shorter operand are treated as 0. A remaining
-- 1-bit of @rhs@ with no @lhs@ bit left yields false.
encodeGE :: forall m. PrimMonad m => Tseitin.Encoder m -> [SAT.Lit] -> [Bool] -> m SAT.Lit
encodeGE enc lhs rhs = do
  true <- Tseitin.encodeConj enc []
  loop lhs rhs true
  where
    loop :: [SAT.Lit] -> [Bool] -> SAT.Lit -> m SAT.Lit
    loop ls bs acc =
      case (ls, bs) of
        ([], []) -> return acc
        ([], True : _) -> Tseitin.encodeDisj enc []  -- false
        ([], False : bs') -> loop [] bs' acc
        (l : ls', True : bs') -> Tseitin.encodeConj enc [l, acc] >>= loop ls' bs'
        (l : ls', False : bs') -> Tseitin.encodeDisj enc [l, acc] >>= loop ls' bs'
        (l : ls', []) -> Tseitin.encodeDisj enc [l, acc] >>= loop ls' []