{-|
Module      : What4.Expr.BoolMap
Description : Datastructure for representing a conjunction of predicates
Copyright   : (c) Galois Inc, 2019-2020
License     : BSD3
Maintainer  : rdockins@galois.com

Declares a datatype for representing n-way conjunctions or disjunctions
in a way that efficiently captures important algebraic
laws like commutativity, associativity and resolution.
-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ViewPatterns #-}
module What4.Expr.BoolMap
  ( BoolMap
  , var
  , addVar
  , fromVars
  , combine
  , Polarity(..)
  , negatePolarity
  , contains
  , isInconsistent
  , isNull
  , BoolMapView(..)
  , viewBoolMap
  , traverseVars
  , reversePolarities
  , removeVar
  , Wrap(..)
  ) where

import           Control.Lens (_1, over)
import           Data.Hashable
import           Data.List (foldl')
import           Data.List.NonEmpty (NonEmpty(..))
import           Data.Kind (Type)
import           Data.Parameterized.Classes

import           What4.BaseTypes
import qualified What4.Utils.AnnotatedMap as AM
import           What4.Utils.IncrHash

-- | Describes the occurrence of a variable or expression: whether it
--   appears as-is ('Positive') or negated ('Negative').
data Polarity = Positive | Negative
  deriving (Eq, Ord, Show)

-- | Each polarity hashes as a distinct small integer tag.
instance Hashable Polarity where
  hashWithSalt s Positive = hashWithSalt s (0 :: Int)
  hashWithSalt s Negative = hashWithSalt s (1 :: Int)

-- | Swap a polarity value: 'Positive' becomes 'Negative' and vice versa.
negatePolarity :: Polarity -> Polarity
negatePolarity Positive = Negative
negatePolarity Negative = Positive

-- | A newtype wrapper around a parameterized value @f x@.  It gives the
--   value ordinary 'Eq', 'Ord' and 'Hashable' instances, built from
--   'TestEquality', 'OrdF' and 'HashableF' respectively, so that it can
--   serve as the key type of the annotated map inside 'BoolMap'.
newtype Wrap (f :: k -> Type) (x :: k) = Wrap { unWrap :: f x }

-- | Two wrapped values are equal exactly when 'testEquality' succeeds on
--   the underlying values.
instance TestEquality f => Eq (Wrap f x) where
  Wrap a == Wrap b = isJust (testEquality a b)
-- | Ordering is inherited from the 'OrdF' instance of the wrapped functor.
instance OrdF f => Ord (Wrap f x) where
  compare (Wrap a) (Wrap b) = toOrdering (compareF a b)
-- | Hashing delegates to the 'HashableF' instance of the wrapped functor.
instance (HashableF f, TestEquality f) => Hashable (Wrap f x) where
  hashWithSalt s (Wrap a) = hashWithSaltF s a

-- | This data structure keeps track of a collection of expressions
--   together with their polarities. Such a collection might represent
--   either a conjunction or a disjunction of expressions.  The
--   implementation uses a map from expression values to their
--   polarities, and thus automatically implements the associative,
--   commutative and idempotency laws common to both conjunctions and
--   disjunctions.  Moreover, if the same expression occurs in the
--   collection with opposite polarities, the entire collection
--   collapses via a resolution step to an \"inconsistent\" map.  For
--   conjunctions this corresponds to a contradiction and
--   represents false; for disjunctions, this corresponds to the law of
--   the excluded middle and represents true.
--
--   Every map entry also carries an 'IncrHash' annotation.  The
--   'Hashable' instance reads the map-wide annotation ('AM.annotation')
--   instead of walking the entries.
data BoolMap (f :: BaseType -> Type)
  = InconsistentMap
    -- ^ Some expression occurred with both polarities.
  | BoolMap !(AM.AnnotatedMap (Wrap f BaseBoolType) IncrHash Polarity)
    -- ^ Each expression mapped to its polarity, annotated with the
    --   'IncrHash' of the (expression, polarity) pair.

-- | Structural equality.  All inconsistent maps are equal to each other.
--   Two consistent maps are equal when they hold the same expressions with
--   the same polarities; the 'IncrHash' annotations are not compared.
instance OrdF f => Eq (BoolMap f) where
  InconsistentMap == InconsistentMap = True
  BoolMap m1 == BoolMap m2 = AM.eqBy (==) m1 m2
  _ == _ = False


-- | Traverse the expressions in a bool map, and rebuild the map.
--
--   The results are reinserted with 'fromVars', so resolution is applied
--   again.  If two expressions are mapped to the same result with opposite
--   polarities, the rebuilt map is inconsistent.  An inconsistent input
--   stays inconsistent, and the callback is never invoked for it.
traverseVars :: (Applicative m, HashableF g, OrdF g) =>
  (f BaseBoolType -> m (g BaseBoolType)) ->
  BoolMap f -> m (BoolMap g)
traverseVars _ InconsistentMap = pure InconsistentMap
traverseVars f (BoolMap m) =
  fromVars <$> traverse (_1 (f . unWrap)) (AM.toList m)

-- | The incremental hash stored as the annotation of a single map entry.
--   The hash of the polarity is used as the salt for hashing the
--   expression, so the same expression with opposite polarities gets
--   different annotations.
elementHash :: HashableF f => f BaseBoolType -> Polarity -> IncrHash
elementHash x p = mkIncrHash (hashWithSaltF (hash p) x)

-- | Hashes a bool map using the following scheme:
--
--   * An inconsistent map hashes as the tag 0.
--   * A consistent map hashes as the tag 1, mixed with its map-wide
--     'IncrHash' annotation ('AM.annotation') when there is one.  The
--     entries themselves are not walked.
instance (OrdF f, HashableF f) => Hashable (BoolMap f) where
  hashWithSalt s InconsistentMap = hashWithSalt s (0 :: Int)
  hashWithSalt s (BoolMap m) =
    case AM.annotation m of
      Nothing -> hashWithSalt s (1 :: Int)
      Just h  -> hashWithSalt (hashWithSalt s (1 :: Int)) h

-- | A view of a bool map that exposes which of its three possible states
--   it is in; see 'viewBoolMap'.
data BoolMapView f
  = -- | A bool map with no expressions; represents the unit of the
    --   corresponding operation.
    BoolMapUnit
  | -- | An inconsistent bool map; represents the dual of the operation unit.
    BoolMapDualUnit
  | -- | The terms appearing in the bool map, of which there is at least one.
    BoolMapTerms (NonEmpty (f BaseBoolType, Polarity))

-- | Deconstruct the given bool map for later processing.
--
--   * An inconsistent map is viewed as 'BoolMapDualUnit'.
--   * An empty map is viewed as 'BoolMapUnit'.
--   * Any other map yields its entries, unwrapped, in the order produced
--     by 'AM.toList'.
viewBoolMap :: BoolMap f -> BoolMapView f
viewBoolMap InconsistentMap = BoolMapDualUnit
viewBoolMap (BoolMap m) =
  case AM.toList m of
    [] -> BoolMapUnit
    (Wrap x, p) : rest -> BoolMapTerms ((x, p) :| map (over _1 unWrap) rest)

-- | Returns true for an inconsistent bool map, i.e. one in which some term
--   was seen with both polarities.
isInconsistent :: BoolMap f -> Bool
isInconsistent InconsistentMap = True
isInconsistent (BoolMap _)     = False

-- | Returns true for a \"null\" bool map with no terms.  An inconsistent
--   map is never considered null.
isNull :: BoolMap f -> Bool
isNull InconsistentMap = False
isNull (BoolMap m)     = AM.null m

-- | Produce a singleton bool map, consisting of just the given term with
--   the given polarity.
var :: (HashableF f, OrdF f) => f BaseBoolType -> Polarity -> BoolMap f
var x p = BoolMap (AM.singleton (Wrap x) (elementHash x p) p)

-- | Add a variable to a bool map, performing a resolution step if possible.
--
--   * If the term is absent, it is inserted with the given polarity.
--   * If it is already present with the same polarity, the map is unchanged.
--   * If it is present with the opposite polarity, the map collapses to
--     'InconsistentMap'.
--
--   Adding to an inconsistent map leaves it inconsistent.
addVar :: (HashableF f, OrdF f) => f BaseBoolType -> Polarity -> BoolMap f -> BoolMap f
addVar _ _ InconsistentMap = InconsistentMap
addVar x p1 (BoolMap bm) = maybe InconsistentMap BoolMap $ AM.alterF f (Wrap x) bm
  where
    -- The outer 'Maybe' reports consistency ('Nothing' means a polarity
    -- clash).  The inner 'Maybe' is the entry handed back to 'AM.alterF'.
    f :: Maybe (IncrHash, Polarity) -> Maybe (Maybe (IncrHash, Polarity))
    f Nothing = Just (Just (elementHash x p1, p1))
    f el@(Just (_, p2))
      | p1 == p2  = Just el
      | otherwise = Nothing

-- | Generate a bool map from a list of terms and polarities by repeatedly
--   calling 'addVar', starting from the empty (consistent) map.  The fold
--   is strict in the accumulated map.
fromVars :: (HashableF f, OrdF f) => [(f BaseBoolType, Polarity)] -> BoolMap f
fromVars = foldl' (\m (x, p) -> addVar x p m) (BoolMap AM.empty)

-- | Merge two bool maps, performing resolution as necessary.  The result
--   is inconsistent if either input is inconsistent, or if some term
--   occurs in both inputs with opposite polarities.
combine :: OrdF f => BoolMap f -> BoolMap f -> BoolMap f
combine InconsistentMap _ = InconsistentMap
combine _ InconsistentMap = InconsistentMap
combine (BoolMap m1) (BoolMap m2) =
    maybe InconsistentMap BoolMap $ AM.mergeA f m1 m2
  where
    -- Receives both entries for one key.  It keeps the left entry when the
    -- polarities agree, and fails the whole merge otherwise.
    f _k (v, p1) (_, p2)
      | p1 == p2  = Just (v, p1)
      | otherwise = Nothing

-- | Test if the bool map contains the given term, and return the polarity
--   of that term if so.  An inconsistent map reports 'Nothing' for every
--   term.
contains :: OrdF f => BoolMap f -> f BaseBoolType -> Maybe Polarity
contains InconsistentMap _ = Nothing
contains (BoolMap m) x = snd <$> AM.lookup (Wrap x) m

-- | Swap the polarities of the terms in the given bool map.  An
--   inconsistent map stays inconsistent.
--
--   NOTE(review): 'fmap' rewrites only the stored polarities.  It cannot
--   recompute the per-entry 'IncrHash' annotations, which 'elementHash'
--   derived from the /old/ polarity.  As a result, the reversed map can
--   compare equal ('==') to a freshly built map (e.g. from 'var' or
--   'fromVars') with the same contents, yet hash differently.  Fixing this
--   requires a 'HashableF f' constraint to recompute the annotations,
--   which would change this function's signature.
reversePolarities :: OrdF f => BoolMap f -> BoolMap f
reversePolarities InconsistentMap = InconsistentMap
reversePolarities (BoolMap m) = BoolMap $! fmap negatePolarity m

-- | Remove the given term from the bool map.  The map is unchanged
--   if inconsistent or if the term does not occur.
removeVar :: OrdF f => BoolMap f -> f BaseBoolType -> BoolMap f
removeVar InconsistentMap _ = InconsistentMap
removeVar (BoolMap m) x = BoolMap (AM.delete (Wrap x) m)