{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.BitVector.Solver
-- Copyright   :  (c) Masahiro Sakai 2016
-- License     :  BSD-style
-- 
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  experimental
--
-----------------------------------------------------------------------------
module ToySolver.BitVector.Solver
  (
  -- * BitVector solver
    Solver
  , newSolver
  , newVar
  , newVar'
  , assertAtom
  , check
  , getModel
  , explain
  , pushBacktrackPoint
  , popBacktrackPoint
  ) where

import Prelude hiding (repeat)
import Control.Monad
import qualified Data.Foldable as F
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.IORef
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid
#endif
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Unboxed as VU
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq
import ToySolver.Data.BoolExpr
import ToySolver.Data.Boolean
import ToySolver.Data.OrdRel
import qualified ToySolver.Internal.Data.SeqQueue as SQ
import qualified ToySolver.Internal.Data.Vec as Vec
import qualified ToySolver.SAT as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin

import ToySolver.BitVector.Base

-- ------------------------------------------------------------------------

-- | State of the bit-vector solver.
--
-- Every bit-vector term is bit-blasted into a vector of SAT literals whose
-- index 0 is the least significant bit.
data Solver = Solver
  { svVars        :: Vec.Vec SBV
    -- ^ Literal vector of each bit-vector variable, indexed by 'varId'.
  , svSATSolver   :: SAT.Solver
    -- ^ Underlying SAT solver.
  , svTseitin     :: Tseitin.Encoder IO
    -- ^ Tseitin encoder attached to 'svSATSolver' (see 'newSolver').
  , svEncTable    :: IORef (Map Expr SBV)
    -- ^ Cache of already-encoded compound expressions.
  , svDivRemTable :: IORef [(SBV, SBV, SBV, SBV)]
    -- ^ Every unsigned division encoded so far, as
    -- @(dividend, divisor, quotient, remainder)@.
  , svAtomTable   :: IORef (Map NormalizedAtom SAT.Lit)
    -- ^ Cache mapping each normalized atom to the literal that defines it.
  , svContexts    :: Vec.Vec (IntMap (Maybe Int))
    -- ^ Stack of assertion contexts (one per backtrack point); each maps an
    -- assumption literal to the optional label it was asserted with.
  }

-- | Create a fresh solver with no variables, no cached encodings and a
-- single (base) assertion context.
newSolver :: IO Solver
newSolver = do
  varVec <- Vec.new
  satSolver <- SAT.newSolver
  encoder <- Tseitin.newEncoder satSolver
  encCache <- newIORef Map.empty
  divRems <- newIORef []
  atomCache <- newIORef Map.empty
  ctxStack <- Vec.new
  -- the base context: stack depth 1
  Vec.push ctxStack IntMap.empty
  pure Solver
    { svVars = varVec
    , svSATSolver = satSolver
    , svTseitin = encoder
    , svEncTable = encCache
    , svDivRemTable = divRems
    , svAtomTable = atomCache
    , svContexts = ctxStack
    }

-- | Create a fresh bit-vector variable of the given width and return it as
-- an expression.
newVar :: Solver -> Int -> IO Expr
newVar solver = fmap EVar . newVar' solver

-- | Create a fresh bit-vector variable of width @w@, allocating one SAT
-- variable per bit.  Variable ids are assigned consecutively starting at 0.
newVar' :: Solver -> Int -> IO Var
newVar' solver w = do
  let table = svVars solver
  lits <- VG.fromList <$> SAT.newVars (svSATSolver solver) w
  idx <- Vec.getSize table
  Vec.push table lits
  pure Var{ varId = idx, varWidth = w }

-- | The relations every atom is normalized to.
data NormalizedRel
  = NRSLt -- ^ signed less-than
  | NRULt -- ^ unsigned less-than
  | NREql -- ^ equality
  deriving (Show, Eq, Ord, Enum, Bounded)

-- | A normalized relation applied to two operands (left, right); used as the
-- key under which atom encodings are cached in 'svAtomTable'.
data NormalizedAtom
  = NormalizedAtom NormalizedRel Expr Expr
  deriving (Show, Eq, Ord)

-- | Rewrite an atom as a 'NormalizedAtom' plus a polarity: the original atom
-- holds iff the normalized atom evaluates to the returned 'Bool'.
--
-- The 'Bool' field of 'Rel' selects signed ('True') or unsigned ('False')
-- comparison.  @a > b@ becomes @b < a@, @a <= b@ becomes @not (b < a)@,
-- @a >= b@ becomes @not (a < b)@ and @a /= b@ becomes @not (a = b)@.
normalizeAtom :: Atom -> (NormalizedAtom, Bool)
normalizeAtom (Rel (OrdRel lhs op rhs) signedCmp) =
  case op of
    Lt  -> (NormalizedAtom ltRel lhs rhs, True)
    Gt  -> (NormalizedAtom ltRel rhs lhs, True)
    Le  -> (NormalizedAtom ltRel rhs lhs, False)
    Ge  -> (NormalizedAtom ltRel lhs rhs, False)
    Eql -> (NormalizedAtom NREql lhs rhs, True)
    NEq -> (NormalizedAtom NREql lhs rhs, False)
  where
    ltRel = if signedCmp then NRSLt else NRULt

-- | Assert an atom, optionally tagged with a label.
--
-- The atom is normalized and its defining literal is looked up in, or added
-- to, 'svAtomTable'.  An unlabelled atom asserted while only the base
-- context exists becomes a permanent clause.  In every other case the
-- (polarity-adjusted) literal is recorded in the top context together with
-- the label.  'check' then passes it as an assumption, and 'explain' can
-- report the label back.  Re-asserting the same literal in the same
-- context replaces its previously recorded label.
assertAtom :: Solver -> Atom -> Maybe Int -> IO ()
assertAtom solver atom label = do
  let (natom@(NormalizedAtom rel lhs rhs), positive) = normalizeAtom atom
      contexts = svContexts solver
      encodeAtom = do
        s <- encodeExpr solver lhs
        t <- encodeExpr solver rhs
        Tseitin.encodeFormula (svTseitin solver) $ case rel of
          NRULt -> isULT s t
          NRSLt -> isSLT s t
          NREql -> isEQ s t
  atomTable <- readIORef (svAtomTable solver)
  lit <- case Map.lookup natom atomTable of
    Just cached -> pure cached
    Nothing -> do
      fresh <- encodeAtom
      writeIORef (svAtomTable solver) (Map.insert natom fresh atomTable)
      pure fresh
  let l = if positive then lit else negate lit
  depth <- Vec.getSize contexts
  if isNothing label && depth == 1
    then SAT.addClause (svTseitin solver) [l]
    else Vec.modify contexts (depth - 1) (IntMap.insert l label)

-- | Check satisfiability of everything asserted so far, using the literals
-- recorded in the current (top) context as assumptions.
check :: Solver -> IO Bool
check solver = do
  let ctxs = svContexts solver
  top <- Vec.getSize ctxs >>= Vec.read ctxs . subtract 1
  SAT.solveWith (svSATSolver solver) (IntMap.keys top)

-- | Read the SAT model back into bit-vector values.
--
-- Returns the value of every variable (indexed by 'varId') together with
-- two tables for unsigned division and remainder by zero.  Each table maps a
-- dividend to the quotient (or remainder) the model chose.  Presumably only
-- meaningful after 'check' returned 'True'.
getModel :: Solver -> IO Model
getModel solver = do
  satModel <- SAT.getModel (svSATSolver solver)
  varBits <- Vec.getElems (svVars solver)
  divRems <- readIORef (svDivRemTable solver)
  let decode = fromAscBits . map (SAT.evalLit satModel) . VG.toList
      allZero = not . or . toAscBits
      env = VG.fromList (map decode varBits)
      -- only divisions whose divisor is zero in this model
      byZero = [ (decode s, decode d, decode r)
               | (s, t, d, r) <- divRems
               , allZero (decode t) ]
      divTable = Map.fromList [ (x, q) | (x, q, _) <- byZero ]
      remTable = Map.fromList [ (x, m) | (x, _, m) <- byZero ]
  pure (env, divTable, remTable)

-- | Return the labels that the current (top) context attaches to the SAT
-- solver's failed assumptions.  Assumptions asserted without a label are
-- omitted.
explain :: Solver -> IO IntSet
explain solver = do
  failed <- SAT.getFailedAssumptions (svSATSolver solver)
  let ctxs = svContexts solver
  top <- Vec.getSize ctxs >>= Vec.read ctxs . subtract 1
  return $ IntSet.fromList $ mapMaybe (top IntMap.!) failed

-- | Open a new context initialised as a copy of the current top context, so
-- that everything asserted so far stays active in it.
pushBacktrackPoint :: Solver -> IO ()
pushBacktrackPoint solver = do
  let ctxs = svContexts solver
  top <- Vec.getSize ctxs >>= Vec.read ctxs . subtract 1
  Vec.push ctxs top

-- | Discard the top context and every labelled or context-local assertion
-- made in it.  NOTE(review): nothing here prevents popping the base context.
popBacktrackPoint :: Solver -> IO ()
popBacktrackPoint solver = void (Vec.pop (svContexts solver))

-- ------------------------------------------------------------------------

-- | A bit-blasted bit-vector: one SAT literal per bit, index 0 being the
-- least significant bit (compare the 'OpExtract' and 'OpConcat' cases of
-- 'encodeExpr').
type SBV = VU.Vector SAT.Lit

-- | Bit-blast an expression into an 'SBV' (one literal per bit, least
-- significant bit first).
--
-- Compound expressions are memoized in 'svEncTable', so each distinct
-- subexpression is encoded once.  Constants and variables bypass the cache
-- and are encoded directly by enc'.
encodeExpr :: Solver -> Expr -> IO SBV
encodeExpr solver = enc
  where
    enc e@(EConst _) = enc' e
    enc e@(EVar _) = enc' e
    -- compound expression: consult the cache first
    enc e = do
      table <- readIORef (svEncTable solver)
      case Map.lookup e table of
        Just vs -> return vs
        Nothing -> do
          vs <- enc' e
          -- modifyIORef re-reads the table, so entries inserted by the
          -- recursive enc calls inside enc' are preserved
          modifyIORef (svEncTable solver) (Map.insert e vs)
          return vs

    -- each constant bit becomes the empty conjunction (true) or the empty
    -- disjunction (false)
    enc' (EConst bs) =
      liftM VU.fromList $ forM (toAscBits bs) $ \b ->
        if b
        then Tseitin.encodeConj (svTseitin solver) []
        else Tseitin.encodeDisj (svTseitin solver) []
    -- variables got their literals in newVar'
    enc' (EVar v) = Vec.read (svVars solver) (varId v)
    enc' (EOp1 op arg) = do
      arg' <- enc arg
      case op of
        -- bits i down to j inclusive (index 0 = least significant bit)
        OpExtract i j -> do
          unless (VG.length arg' > i && i >= j && j >= 0) $
            error ("invalid extract " ++ show (i,j) ++ " on bit-vector of length " ++ show (VG.length arg') ++ " : " ++ show arg)
          return $ VG.slice j (i - j + 1) arg'
        -- bitwise not is literal negation; no new variables are needed
        OpNot -> return $ VG.map negate arg'
        OpNeg -> encodeNegate (svTseitin solver) arg'
    enc' (EOp2 op arg1 arg2) = do
      arg1' <- enc arg1
      arg2' <- enc arg2
      case op of
        -- arg1 supplies the high bits, so it goes after arg2 (LSB first)
        OpConcat -> return (arg2' <> arg1')
        OpAnd -> VG.zipWithM (\l1 l2 -> Tseitin.encodeConj (svTseitin solver) [l1,l2]) arg1' arg2'
        OpOr  -> VG.zipWithM (\l1 l2 -> Tseitin.encodeDisj (svTseitin solver) [l1,l2]) arg1' arg2'
        OpXOr -> VG.zipWithM (Tseitin.encodeXOR (svTseitin solver)) arg1' arg2'
        -- a single bit that is 1 iff the operands are equal
        OpComp -> VG.singleton <$> Tseitin.encodeFormula (svTseitin solver) (isEQ arg1' arg2')
        -- addition and multiplication wrap around (allowOverflow = True)
        OpAdd -> encodeSum (svTseitin solver) (VG.length arg1') True [arg1', arg2']
        OpMul -> encodeMul (svTseitin solver) True arg1' arg2'
        -- each call introduces its own quotient/remainder pair
        OpUDiv -> fst <$> encodeDivRem solver arg1' arg2'
        OpURem -> snd <$> encodeDivRem solver arg1' arg2'
        OpSDiv -> encodeSDiv solver arg1' arg2'
        OpSRem -> encodeSRem solver arg1' arg2'
        OpSMod -> encodeSMod solver arg1' arg2'
        OpShl  -> encodeShl (svTseitin solver) arg1' arg2'
        OpLShr -> encodeLShr (svTseitin solver) arg1' arg2'
        OpAShr -> encodeAShr (svTseitin solver) arg1' arg2'

-- | Shift-and-add multiplication; the result width @w@ is the width of the
-- first operand.
--
-- With @allowOverflow = True@ the product is taken modulo @2^w@: partial
-- products are truncated to @w@ bits.  With @allowOverflow = False@ the full
-- partial products are summed, and 'encodeSum' adds clauses forbidding any
-- set bit at position @w@ or above.  This asserts that the product fits in
-- @w@ bits.
encodeMul :: Tseitin.Encoder IO -> Bool -> SBV -> SBV -> IO SBV
encodeMul enc allowOverflow arg1 arg2 = do
  let w = VG.length arg1
      -- the bits of arg1 that row i contributes
      operand i
        | allowOverflow = VG.take (w - i) arg1
        | otherwise = arg1
  zeroLit <- Tseitin.encodeDisj enc [] -- False
  let partialProduct i y = do
        row <- VG.mapM (\x -> Tseitin.encodeConj enc [x, y]) (operand i)
        return (VG.replicate i zeroLit <> row)
  rows <- zipWithM partialProduct [0 ..] (VG.toList arg2)
  encodeSum enc w allowOverflow rows

-- | Sum a list of bit-vectors into a @w@-bit result.
--
-- Every input bit is placed in the queue ("bucket") for its bit position.
-- Columns are then processed from position 0 upwards:
--
-- * three or more bits: a full adder puts its sum bit back into the same
--   column and its carry into the next column;
-- * two bits: a half adder does the same;
-- * a single bit: it is the output bit directly;
-- * no bits: the output bit is constant false.
--
-- If @allowOverflow@ is 'False', every bit still left in a column at
-- position @w@ or above is asserted false with a unit clause.  This requires
-- the true sum to fit in @w@ bits.  These clauses are added directly through
-- the encoder, not under an assumption.  If @allowOverflow@ is 'True', such
-- bits are discarded, giving the sum modulo @2^w@.
encodeSum :: Tseitin.Encoder IO -> Int -> Bool -> [SBV] -> IO SBV
encodeSum enc w allowOverflow xss = do
  (buckets :: IORef (Seq (SQ.SeqQueue IO SAT.Lit))) <- newIORef Seq.empty
  -- enqueue literal x into bucket i, creating missing buckets on demand
  let insert i x = do
        bs <- readIORef buckets
        let n = Seq.length bs
        q <- if i < n then do
               return $ Seq.index bs i
             else do
               qs <- replicateM (i+1 - n) SQ.newFifo
               let bs' = bs Seq.>< Seq.fromList qs
               writeIORef buckets bs'
               return $ Seq.index bs' i
        SQ.enqueue q x

  -- bit j of every input goes to bucket j
  forM_ xss $ \xs -> do
#if MIN_VERSION_vector(0,11,0)
    VG.imapM insert xs
#else
    VG.mapM (uncurry insert) (VG.indexed xs)
#endif

  -- ret accumulates output bits in reverse order
  let loop i ret
        -- all w output bits are produced; columns below w are empty now,
        -- so any remaining bit has weight of at least 2^w
        | i >= w = do
            unless allowOverflow $ do
              bs <- readIORef buckets
              forM_ (F.toList bs) $ \q -> do
                ls <- SQ.dequeueBatch q
                forM_ ls $ \l -> do
                  SAT.addClause  enc [-l]
            return (reverse ret)
        | otherwise = do
            bs <- readIORef buckets
            let n = Seq.length bs
            if i >= n then do
              b <- Tseitin.encodeDisj enc [] -- False
              loop (i+1) (b : ret)
            else do
              let q = Seq.index bs i
              m <- SQ.queueSize q
              case m of
                0 -> do
                  b <- Tseitin.encodeDisj enc [] -- False
                  loop (i+1) (b : ret)
                1 -> do
                  Just b <- SQ.dequeue q
                  loop (i+1) (b : ret)
                2 -> do
                  -- half adder: sum stays in column i, carry moves to i+1
                  Just b1 <- SQ.dequeue q
                  Just b2 <- SQ.dequeue q
                  s <- encodeHASum enc b1 b2
                  c <- encodeHACarry enc b1 b2
                  insert (i+1) c
                  loop (i+1) (s : ret)
                _ -> do
                  -- full adder: three bits become one sum bit (column i)
                  -- and one carry (column i+1); then revisit column i
                  Just b1 <- SQ.dequeue q
                  Just b2 <- SQ.dequeue q
                  Just b3 <- SQ.dequeue q
                  s <- Tseitin.encodeFASum enc b1 b2 b3
                  c <- Tseitin.encodeFACarry enc b1 b2 b3
                  insert i s
                  insert (i+1) c
                  loop i ret
  VU.fromList <$> loop 0 []

-- | Sum bit of a half adder: @a XOR b@.
encodeHASum :: Tseitin.Encoder IO -> SAT.Lit -> SAT.Lit -> IO SAT.Lit
encodeHASum enc a b = Tseitin.encodeXOR enc a b

-- | Carry bit of a half adder: @x AND y@.
encodeHACarry :: Tseitin.Encoder IO -> SAT.Lit -> SAT.Lit -> IO SAT.Lit
encodeHACarry enc x y = Tseitin.encodeConj enc [x, y]

-- | Two's-complement negation: scanning from the least significant bit,
-- bit i of the input is flipped iff some lower bit is set.
encodeNegate :: Tseitin.Encoder IO -> SBV -> IO SBV
encodeNegate enc s = do
  noneSeen <- Tseitin.encodeDisj enc [] -- False: no lower bit seen yet
  let step (seen, acc) x = do
        y <- Tseitin.encodeITE enc seen (- x) x
        seen' <- Tseitin.encodeDisj enc [seen, x]
        return (seen', y : acc)
  (_, ys) <- foldM step (noneSeen, []) (VG.toList s)
  return (VU.fromList (reverse ys))

-- | Absolute value of a signed bit-vector.  The result is a fresh vector
-- constrained to equal the input when its most significant bit is 0, and
-- its two's-complement negation otherwise.
encodeAbs :: Tseitin.Encoder IO -> SBV -> IO SBV
encodeAbs enc s
  | VG.null s = return VG.empty
  | otherwise = do
      r <- VG.fromList <$> SAT.newVars enc (VG.length s)
      negS <- encodeNegate enc s
      let nonNegative = Atom (negate (VG.last s))
      Tseitin.addFormula enc $
        ite nonNegative (isEQ r s) (isEQ r negS)
      return r

-- | Logical shift left of @s@ by the amount @t@; both operands must have the
-- same width.
--
-- Built as a chain of stages.  Stage i shifts by @2^i@ when bit i of @t@ is
-- set, filling vacated positions with false.  Amounts of @w@ or more
-- therefore yield all zeros.
encodeShl :: Tseitin.Encoder IO -> SBV -> SBV -> IO SBV
encodeShl enc s t = do
  let w = VG.length s
  when (w /= VG.length t) $ error "invalid width"
  zeroLit <- Tseitin.encodeDisj enc [] -- False
  let stage cur (i, sel) =
        VG.generateM w $ \j ->
          let src = toInteger j - 2 ^ i
              moved = if src >= 0 then cur VG.! fromInteger src else zeroLit
          in Tseitin.encodeITE enc sel moved (cur VG.! j)
  foldM stage s (zip [(0 :: Int) ..] (VG.toList t))

-- | Logical shift right of @s@ by the amount @t@; both operands must have
-- the same width.
--
-- Built as a chain of stages.  Stage i shifts by @2^i@ when bit i of @t@ is
-- set, filling vacated high positions with false.  Amounts of @w@ or more
-- therefore yield all zeros.
encodeLShr :: Tseitin.Encoder IO -> SBV -> SBV -> IO SBV
encodeLShr enc s t = do
  let w = VG.length s
  when (w /= VG.length t) $ error "invalid width"
  zeroLit <- Tseitin.encodeDisj enc [] -- False
  let stage cur (i, sel) =
        VG.generateM w $ \j ->
          let src = toInteger j + 2 ^ i
              moved
                | src < toInteger (VG.length cur) = cur VG.! fromInteger src
                | otherwise = zeroLit
          in Tseitin.encodeITE enc sel moved (cur VG.! j)
  foldM stage s (zip [(0 :: Int) ..] (VG.toList t))

-- | Arithmetic shift right of @s@ by the amount @t@ (SMT-LIB @bvashr@); both
-- operands must have the same width.
--
-- Follows the SMT-LIB definition:
--
-- > (bvashr s t) = (ite (= msb_s 0) (bvlshr s t) (bvnot (bvlshr (bvnot s) t)))
--
-- For a negative @s@ the result must use bitwise complement, not
-- two's-complement negation ('encodeNegate').  For example, with 4-bit
-- @s = 1111@ (i.e. -1) and @t = 1@ the result must be @1111@, whereas
-- negate-shift-negate yields @0000@.  Shift amounts of @w@ or more give
-- all ones for negative @s@ and all zeros otherwise.
encodeAShr :: Tseitin.Encoder IO -> SBV -> SBV -> IO SBV
encodeAShr enc s t = do
  let w = VG.length s
  when (w /= VG.length t) $ error "invalid width"
  if w == 0 then
    return VG.empty
  else do
    let msb_s = VG.last s
        -- bitwise complement: negating a SAT literal flips the bit it denotes
        flipBits :: SBV -> SBV
        flipBits = VG.map negate
    r <- VG.fromList <$> SAT.newVars enc w
    a <- encodeLShr enc s t
    b <- flipBits <$> encodeLShr enc (flipBits s) t
    Tseitin.addFormula enc $
      ite (Atom (-msb_s)) (isEQ r a) (isEQ r b)
    return r

-- | Unsigned division and remainder of @s@ by @t@, returning fresh
-- @(quotient, remainder)@ vectors with the width of @s@.
--
-- When @t@ is nonzero, the encoding asserts @s = quotient * t + remainder@
-- (computed without overflow) and @remainder < t@.  When @t@ is zero,
-- quotient and remainder are otherwise unconstrained.  They must, however,
-- agree with every earlier division of the same width that has an equal
-- dividend and a zero divisor.  The new entry is then appended to
-- 'svDivRemTable'.
encodeDivRem :: Solver -> SBV -> SBV -> IO (SBV, SBV)
encodeDivRem solver s t = do
  let w = VG.length s
      sat = svSATSolver solver
      enc = svTseitin solver
  quo <- VG.fromList <$> SAT.newVars sat w
  rmd <- VG.fromList <$> SAT.newVars sat w
  prod <- encodeMul enc False quo t
  recombined <- encodeSum enc w False [prod, rmd]
  earlier <- readIORef (svDivRemTable solver)
  -- division by zero: equal dividends must give equal results
  let congruent =
        And [ (isEQ s s' .&&. isZero t') .=>. (isEQ quo q' .&&. isEQ rmd r')
            | (s', t', q', r') <- earlier
            , VG.length s' == w ]
  Tseitin.addFormula enc $
    ite (isZero t) congruent (isEQ s recombined .&&. isULT rmd t)
  modifyIORef (svDivRemTable solver) ((s, t, quo, rmd) :)
  return (quo, rmd)

-- | Signed division (SMT-LIB @bvsdiv@): the unsigned quotient of the
-- operands' magnitudes, negated when exactly one operand is negative.
-- Both operands must have the same width.
encodeSDiv :: Solver -> SBV -> SBV -> IO SBV
encodeSDiv solver s t = do
  let w = VG.length s
      enc = svTseitin solver
      udiv x y = fst <$> encodeDivRem solver x y
  when (w /= VG.length t) $ error "invalid width"
  if w == 0
    then return VG.empty
    else do
      negS <- encodeNegate enc s
      negT <- encodeNegate enc t
      res <- VG.fromList <$> SAT.newVars (svSATSolver solver) w
      qPP <- udiv s t
      qNP <- encodeNegate enc =<< udiv negS t
      qPN <- encodeNegate enc =<< udiv s negT
      qNN <- udiv negS negT
      let sNeg = Atom (VG.last s)
          tNeg = Atom (VG.last t)
          sNonNeg = Atom (negate (VG.last s))
          tNonNeg = Atom (negate (VG.last t))
      Tseitin.addFormula enc $
        ite (sNonNeg .&&. tNonNeg) (isEQ res qPP) $
        ite (sNeg .&&. tNonNeg) (isEQ res qNP) $
        ite (sNonNeg .&&. tNeg) (isEQ res qPN) $
        isEQ res qNN
      return res

-- | Signed remainder (SMT-LIB @bvsrem@): the unsigned remainder of the
-- operands' magnitudes, carrying the sign of the dividend.  Both operands
-- must have the same width.
encodeSRem :: Solver -> SBV -> SBV -> IO SBV
encodeSRem solver s t = do
  let w = VG.length s
      enc = svTseitin solver
      urem x y = snd <$> encodeDivRem solver x y
  when (w /= VG.length t) $ error "invalid width"
  if w == 0
    then return VG.empty
    else do
      negS <- encodeNegate enc s
      negT <- encodeNegate enc t
      res <- VG.fromList <$> SAT.newVars (svSATSolver solver) w
      rPP <- urem s t
      rNP <- encodeNegate enc =<< urem negS t
      rPN <- urem s negT
      rNN <- encodeNegate enc =<< urem negS negT
      let sNeg = Atom (VG.last s)
          tNeg = Atom (VG.last t)
          sNonNeg = Atom (negate (VG.last s))
          tNonNeg = Atom (negate (VG.last t))
      Tseitin.addFormula enc $
        ite (sNonNeg .&&. tNonNeg) (isEQ res rPP) $
        ite (sNeg .&&. tNonNeg) (isEQ res rNP) $
        ite (sNonNeg .&&. tNeg) (isEQ res rPN) $
        isEQ res rNN
      return res

-- | Signed modulus (SMT-LIB @bvsmod@): a remainder whose sign follows the
-- divisor.  It is computed from the unsigned remainder @u@ of the operands'
-- absolute values, adjusted by sign case.  Both operands must have the same
-- width.
encodeSMod :: Solver -> SBV -> SBV -> IO SBV
encodeSMod solver s t = do
  let w = VG.length s
      enc = svTseitin solver
  when (w /= VG.length t) $ error "invalid width"
  if w == 0
    then return VG.empty
    else do
      res <- VG.fromList <$> SAT.newVars (svSATSolver solver) w
      absS <- encodeAbs enc s
      absT <- encodeAbs enc t
      u <- snd <$> encodeDivRem solver absS absT
      negU <- encodeNegate enc u
      negUPlusT <- encodeSum enc w True [negU, t]
      uPlusT <- encodeSum enc w True [u, t]
      let sNeg = Atom (VG.last s)
          tNeg = Atom (VG.last t)
          sNonNeg = Atom (negate (VG.last s))
          tNonNeg = Atom (negate (VG.last t))
      Tseitin.addFormula enc $
        ite (isZero u .||. (sNonNeg .&&. tNonNeg)) (isEQ res u) $
        ite (sNeg .&&. tNonNeg) (isEQ res negUPlusT) $
        ite (sNonNeg .&&. tNeg) (isEQ res uPlusT) $
        isEQ res negU
      return res

-- | Formula stating that every bit of the vector is 0 (true for width 0).
isZero :: SBV -> Tseitin.Formula
isZero = And . map (Not . Atom) . VG.toList

-- | Formula stating that two vectors are bitwise equal.  Errors if the
-- widths differ.
isEQ :: SBV -> SBV -> Tseitin.Formula
isEQ bs1 bs2
  | n1 /= n2 = error ("length mismatch: " ++ show n1 ++ " and " ++ show n2)
  | otherwise = And (zipWith bitEq (VG.toList bs1) (VG.toList bs2))
  where
    n1 = VG.length bs1
    n2 = VG.length bs2
    bitEq a b = Equiv (Atom a) (Atom b)

-- | Formula for unsigned @bs1 < bs2@.
--
-- Bits are compared from the most significant end: @bs1@ is smaller if, at
-- the first differing position, its bit is 0 and the other bit is 1.  Equal
-- vectors, including width 0, give 'false'.  Errors if the widths differ.
isULT :: SBV -> SBV -> Tseitin.Formula
isULT bs1 bs2
  | n1 /= n2 = error ("length mismatch: " ++ show n1 ++ " and " ++ show n2)
  | otherwise = foldr step false (zip (msbFirst bs1) (msbFirst bs2))
  where
    n1 = VG.length bs1
    n2 = VG.length bs2
    msbFirst v = VG.toList (VG.reverse v)
    step (x, y) rest =
      (notB (Atom x) .&&. Atom y) .||. ((Atom x .=>. Atom y) .&&. rest)

-- | Formula for signed (two's-complement) @bs1 < bs2@.
--
-- Either @bs1@ is negative and @bs2@ is not, or both have the same sign and
-- the unsigned comparison holds.  Width 0 gives 'false'.  Errors if the
-- widths differ.
isSLT :: SBV -> SBV -> Tseitin.Formula
isSLT bs1 bs2
  | n1 /= n2 = error ("length mismatch: " ++ show n1 ++ " and " ++ show n2)
  | n1 == 0 = false
  | otherwise =
      (Atom msb1 .&&. Not (Atom msb2))
        .||. ((Atom msb1 .<=>. Atom msb2) .&&. isULT bs1 bs2)
  where
    n1 = VG.length bs1
    n2 = VG.length bs2
    msb1 = VG.last bs1
    msb2 = VG.last bs2

-- ------------------------------------------------------------------------

-- | Smoke test: search for 8-bit @x@, @y@ with @x * y = 6@ (mod 2^8), then
-- print the satisfiability result and the model.
_test1 :: IO ()
_test1 = do
  solver <- newSolver
  x <- newVar solver 8
  y <- newVar solver 8
  assertAtom solver (EOp2 OpMul x y .==. nat2bv 8 6) Nothing
  result <- check solver
  print result
  getModel solver >>= print

-- | Smoke test: @x = y@ together with @x / 0 /= y / 0@ is expected to be
-- unsatisfiable, since division by zero is treated as a function of the
-- dividend.
_test2 :: IO ()
_test2 = do
  solver <- newSolver
  x <- newVar solver 8
  y <- newVar solver 8
  let z0 = nat2bv 8 0
      quotBy0 e = EOp2 OpUDiv e z0
  assertAtom solver (quotBy0 x ./=. quotBy0 y) Nothing
  assertAtom solver (x .==. y) Nothing
  check solver >>= print -- False