{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.Encoder.Cardinality.Internal.ParallelCounter
-- Copyright   :  (c) Masahiro Sakai 2019
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.Cardinality.Internal.ParallelCounter
  ( addAtLeastParallelCounter
  , encodeAtLeastParallelCounter
  ) where

import Control.Monad.Primitive
import Control.Monad.State.Strict
import Data.Bits
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin

-- | Assert an at-least constraint on the encoder.
--
-- The constraint is first turned into a single literal by
-- 'encodeAtLeastParallelCounter'; that literal is then added as a unit
-- clause via 'SAT.addClause'.
addAtLeastParallelCounter :: PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m ()
addAtLeastParallelCounter enc constr =
  encodeAtLeastParallelCounter enc constr >>= \lit -> SAT.addClause enc [lit]

-- TODO: consider polarity

-- | Encode an at-least constraint @(lhs, rhs)@ (at least @rhs@ of the
-- literals in @lhs@ hold) into a single literal using a parallel counter.
--
-- For a positive @rhs@ the literals are summed by
-- 'encodeSumParallelCounter' into a binary number truncated to the bit
-- width of @rhs@; the result is the disjunction of \"truncated sum >= rhs\"
-- (see 'encodeGE') and every carry that overflowed the truncated width
-- (an overflow already implies the sum exceeds @rhs@).
--
-- A non-positive @rhs@ is trivially satisfied, so the constant-true literal
-- (the empty conjunction) is returned directly.  This guard also matters
-- for termination: the bit decomposition below clears set bits until the
-- number becomes 0, which never happens for a negative 'Integer'.
encodeAtLeastParallelCounter :: forall m. PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m SAT.Lit
encodeAtLeastParallelCounter enc (lhs,rhs)
  | rhs <= 0 = Tseitin.encodeConj enc [] -- true
  | otherwise = do
      let rhs_bits = bits (fromIntegral rhs)
      (cnt, overflowBits) <- encodeSumParallelCounter enc (length rhs_bits) lhs
      isGE <- encodeGE enc cnt rhs_bits
      Tseitin.encodeDisj enc $ isGE : overflowBits
  where
    -- Binary digits of a non-negative number, least significant bit first,
    -- ending at the most significant set bit (so 0 yields the empty list).
    bits :: Integer -> [Bool]
    bits = go 0
      where
        go :: Int -> Integer -> [Bool]
        go !_ 0 = []
        go i k = testBit k i : go (i+1) (clearBit k i)

-- | Build a parallel counter over the given literals.
--
-- The literals are split recursively into two halves of equal size (plus one
-- left-over literal when the count is odd); the two half counts are added
-- with a ripple of full adders, the left-over literal (or constant false)
-- serving as the carry-in.
--
-- The result is a pair @(sumBits, overflowBits)@: @sumBits@ is the count,
-- least significant bit first and at most @w@ bits wide, and @overflowBits@
-- collects every carry that did not fit into @w@ bits.
encodeSumParallelCounter :: forall m. PrimMonad m => Tseitin.Encoder m -> Int -> [SAT.Lit] -> m ([SAT.Lit], [SAT.Lit])
encodeSumParallelCounter enc w lits = runStateT (count (V.fromList lits)) []
  where
    -- Count the true literals of a vector; overflow carries accumulate in
    -- the state.
    count :: Vector SAT.Lit -> StateT [SAT.Lit] m [SAT.Lit]
    count v
      | V.null v = return []
      | otherwise = do
          let half = V.length v `div` 2
          lowSum <- count (V.take half v)
          highSum <- count (V.take half (V.drop half v))
          carryIn <-
            if even (V.length v)
              then lift $ Tseitin.encodeDisj enc []
              else return $ V.last v
          adder 0 [] lowSum highSum carryIn

    -- Ripple-carry addition of two equally long bit lists (LSB first).
    -- @pos@ is the index of the current bit and @acc@ holds the sum bits
    -- produced so far in reverse order.
    adder :: Int -> [SAT.Lit] -> [SAT.Lit] -> [SAT.Lit] -> SAT.Lit -> StateT [SAT.Lit] m [SAT.Lit]
    adder pos acc as bs carry
      | pos == w = do
          -- The width limit is reached: the pending carry is an overflow.
          modify (carry :)
          return (reverse acc)
      | otherwise =
          case (as, bs) of
            ([], []) ->
              -- Both operands are exhausted: the carry becomes the top bit.
              return (reverse (carry : acc))
            (a : as', b : bs') -> do
              s <- lift $ Tseitin.encodeFASum enc a b carry
              carry' <- lift $ Tseitin.encodeFACarry enc a b carry
              adder (pos + 1) (s : acc) as' bs' carry'
            _ -> error "encodeSumParallelCounter: should not happen"

-- | Encode the comparison @lhs >= rhs@ between a binary number given as
-- literals and a binary constant, both least significant bit first.
--
-- The bits are scanned from the least significant end while carrying a
-- literal that states \"the bits seen so far of @lhs@ are >= those of
-- @rhs@\", starting from constant true (the empty conjunction).
encodeGE :: forall m. PrimMonad m => Tseitin.Encoder m -> [SAT.Lit] -> [Bool] -> m SAT.Lit
encodeGE enc lhs rhs = Tseitin.encodeConj enc [] >>= scan lhs rhs
  where
    scan :: [SAT.Lit] -> [Bool] -> SAT.Lit -> m SAT.Lit
    scan xs bs geSoFar =
      case (xs, bs) of
        ([], []) -> return geSoFar
        -- lhs has run out of bits but rhs still has a set bit: lhs < rhs.
        ([], True : _) -> Tseitin.encodeDisj enc [] -- false
        ([], False : bs') -> scan [] bs' geSoFar
        -- rhs bit is 1: lhs bit must be 1 and the lower part must be >=.
        (x : xs', True : bs') ->
          Tseitin.encodeConj enc [x, geSoFar] >>= scan xs' bs'
        -- rhs bit is 0: lhs bit 1 wins outright, otherwise defer to lower part.
        (x : xs', False : bs') ->
          Tseitin.encodeDisj enc [x, geSoFar] >>= scan xs' bs'
        -- rhs is exhausted: its remaining bits are implicitly 0.
        (x : xs', []) ->
          Tseitin.encodeDisj enc [x, geSoFar] >>= scan xs' []