{-# OPTIONS_GHC -Wall #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.Arith.DifferenceLogic
-- Copyright   :  (c) Masahiro Sakai 2016
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Reference:
--
-- * Albert Oliveras and Enric Rodriguez-Carbonell.
--   “General overview of a T-Solver for Difference Logic”.
--   <https://www.cs.upc.edu/~oliveras/TDV/dl.pdf>
--
-----------------------------------------------------------------------------
module ToySolver.Arith.DifferenceLogic
  ( SimpleAtom (..)
  , Var
  , Diff (..)
  , solve
  ) where

import Data.Hashable
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap
import Data.HashSet (HashSet)
import qualified Data.HashSet as HashSet
import Data.Monoid

import ToySolver.Graph.ShortestPath (bellmanFord, lastInEdge, bellmanFordDetectNegativeCycle, monoid')

-- Fixities: ':-' (precedence 6) binds tighter than ':<=' (precedence 4),
-- so @a :- b :<= k@ parses as @(a :- b) :<= k@ without parentheses.
infixl 6 :-
infix 4 :<=

-- | Variables are identified by plain 'Int's.
type Var = Int

-- | Difference of two variables: @a :- b@ stands for /a - b/.
data Diff = Var :- Var
  deriving (Eq, Ord, Show)

-- | A difference constraint with an upper bound of type @b@:
-- @a :- b :<= k@ represents /a - b ≤ k/.
data SimpleAtom b = Diff :<= b
  deriving (Eq, Ord, Show)

-- | Takes a labeled list of constraints, and returns either
--
-- * an unsatisfiable set of constraints as a set of labels ('Left'), or
--
-- * a satisfying assignment ('Right').
--
-- Every constraint @a :- b :<= k@ is turned into a graph edge from @a@ to
-- @b@ carrying cost @k@ and the constraint's label.  The graph and the
-- table computed by 'bellmanFord' are then handed to
-- 'bellmanFordDetectNegativeCycle'; if it yields a result, the labels it
-- collected are reported as the conflicting constraints, otherwise the
-- negated costs from the table are returned as the assignment.
solve
  :: (Hashable label, Eq label, Real b)
  => [(label, SimpleAtom b)]
  -> Either (HashSet label) (IntMap b)
solve xs =
  -- The fold passed here accumulates the label of every edge it is fed,
  -- as a difference list ('Endo') so that appending stays cheap.
  case bellmanFordDetectNegativeCycle (monoid' (\(_, _, _, l) -> Endo (l :))) g d of
    Just f -> Left $ HashSet.fromList $ appEndo f []
    -- NOTE(review): assuming 'bellmanFord' returns shortest-path costs, an
    -- edge a -> b of cost k gives d(b) <= d(a) + k, hence
    -- (-d(a)) - (-d(b)) <= k; negating the costs thus satisfies a - b <= k.
    Nothing -> Right $ fmap (\(c, _) -> negate c) d
  where
    -- Every variable mentioned on either side of a constraint, deduplicated.
    vs :: [Var]
    vs = HashSet.toList $ HashSet.fromList [v | (_, a :- b :<= _) <- xs, v <- [a, b]]

    -- Adjacency map keyed by the left-hand variable.  'fromListWith' merges
    -- the edge lists of constraints sharing the same left-hand variable;
    -- plain 'IntMap.fromList' would keep only the last such constraint and
    -- silently drop the others.
    g = IntMap.fromListWith (++) [(a, [(b, k, l)]) | (l, a :- b :<= k) <- xs]

    -- Cost table (with the last in-edge per vertex) computed over all
    -- variables in 'vs'.
    d = bellmanFord lastInEdge g vs

-- Satisfiable example: M = {a−b ≤ 2, b−c ≤ 3, c−a ≤ −3}.
--
-- Summing the three bounds around the cycle gives 2 + 3 − 3 = 2 ≥ 0, so the
-- constraints are consistent and 'solve' should produce a 'Right' assignment.
_test_sat :: Either (HashSet Int) (IntMap Int)
_test_sat = solve xs
  where
    -- Constraints labeled 1..3.
    xs :: [(Int, SimpleAtom Int)]
    xs =
      [ (1, a :- b :<= 2)
      , (2, b :- c :<= 3)
      , (3, c :- a :<= -3)
      ]

    -- Total tuple binding instead of the partial pattern @[a,b,c] = [0..2]@.
    a, b, c :: Var
    (a, b, c) = (0, 1, 2)

-- Unsatisfiable example: M = {a−b ≤ 2, b−c ≤ 3, c−a ≤ −7}.
--
-- Summing the three bounds around the cycle gives 2 + 3 − 7 = −2 < 0, so the
-- constraints are contradictory and 'solve' should produce a 'Left' set of
-- labels.
_test_unsat :: Either (HashSet Int) (IntMap Int)
_test_unsat = solve xs
  where
    -- Constraints labeled 1..3.
    xs :: [(Int, SimpleAtom Int)]
    xs =
      [ (1, a :- b :<= 2)
      , (2, b :- c :<= 3)
      , (3, c :- a :<= -7)
      ]

    -- Total tuple binding instead of the partial pattern @[a,b,c] = [0..2]@.
    a, b, c :: Var
    (a, b, c) = (0, 1, 2)