{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Core.Data.Class.Bool
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.Data.Class.Bool
  ( -- * Symbolic equality
    SEq (..),
    SEq' (..),

    -- * Symbolic Boolean operations
    LogicalOp (..),
    SymBoolOp,
    ITEOp (..),
  )
where

import Control.Monad.Except
import Control.Monad.Identity
  ( Identity (Identity),
    IdentityT (IdentityT),
  )
import Control.Monad.Trans.Maybe
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import qualified Data.ByteString as B
import Data.Functor.Sum
import Data.Int
import Data.Word
import GHC.TypeNats
import Generics.Deriving
import Grisette.Core.Data.BV
import {-# SOURCE #-} Grisette.Core.Data.Class.SimpleMergeable
import Grisette.Core.Data.Class.Solvable
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool
import {-# SOURCE #-} Grisette.IR.SymPrim.Data.SymPrim

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.IR.SymPrim
-- >>> :set -XDataKinds
-- >>> :set -XBinaryLiterals
-- >>> :set -XFlexibleContexts
-- >>> :set -XFlexibleInstances
-- >>> :set -XFunctionalDependencies

-- | Helper class used to derive 'SEq' instances: it provides symbolic
-- equality on generic representation types.
class SEq' f where
  infix 4 ==~~

  -- | Symbolic equality on a generic representation; the helper behind the
  -- derived '(==~)'.
  (==~~) :: f p -> f p -> SymBool

-- | 'U1' carries no data, so any two values compare as equal.
instance SEq' U1 where
  _ ==~~ _ = con True
  {-# INLINE (==~~) #-}

-- | 'V1' has no constructors; the comparison is defined as constantly true.
instance SEq' V1 where
  _ ==~~ _ = con True
  {-# INLINE (==~~) #-}

-- | A field of type @c@: delegate to the 'SEq' instance of the field type.
instance SEq c => SEq' (K1 i c) where
  (K1 a) ==~~ (K1 b) = a ==~ b
  {-# INLINE (==~~) #-}

-- | Metadata wrapper: the metadata is ignored and the wrapped
-- representations are compared.
instance SEq' a => SEq' (M1 i c a) where
  (M1 a) ==~~ (M1 b) = a ==~~ b
  {-# INLINE (==~~) #-}

-- | Sums: values on the same side are compared by their payloads; values on
-- different sides are never equal.
instance (SEq' a, SEq' b) => SEq' (a :+: b) where
  (L1 a) ==~~ (L1 b) = a ==~~ b
  (R1 a) ==~~ (R1 b) = a ==~~ b
  -- Mismatched alternatives (L1 against R1 or vice versa).
  _ ==~~ _ = con False
  {-# INLINE (==~~) #-}

-- | Products: equal exactly when both components are equal, expressed as the
-- symbolic conjunction of the component-wise comparisons.
instance (SEq' a, SEq' b) => SEq' (a :*: b) where
  (a1 :*: b1) ==~~ (a2 :*: b2) = (a1 ==~~ a2) &&~ (b1 ==~~ b2)
  {-# INLINE (==~~) #-}

-- | Symbolic equality. Note that we can't use Haskell's 'Eq' class since
-- symbolic comparison won't necessarily return a concrete 'Bool' value.
--
-- >>> let a = 1 :: SymInteger
-- >>> let b = 2 :: SymInteger
-- >>> a ==~ b
-- false
-- >>> a /=~ b
-- true
--
-- >>> let a = "a" :: SymInteger
-- >>> let b = "b" :: SymInteger
-- >>> a ==~ b
-- (= a b)
-- >>> a /=~ b
-- (! (= a b))
--
-- __Note:__ This type class can be derived for algebraic data types.
-- You may need the @DerivingVia@ and @DerivingStrategies@ extensions.
--
-- > data X = ... deriving Generic deriving SEq via (Default X)
class SEq a where
  -- | Symbolic equality test. The default implementation is the symbolic
  -- negation of '(/=~)'.
  (==~) :: a -> a -> SymBool
  a ==~ b = nots $ a /=~ b
  {-# INLINE (==~) #-}
  infix 4 ==~

  -- | Symbolic inequality test. The default implementation is the symbolic
  -- negation of '(==~)'.
  (/=~) :: a -> a -> SymBool
  a /=~ b = nots $ a ==~ b
  {-# INLINE (/=~) #-}
  infix 4 /=~

  -- The two defaults are defined in terms of each other, so an instance must
  -- provide at least one of them.
  {-# MINIMAL (==~) | (/=~) #-}

-- | Generic implementation of 'SEq', intended for use with @deriving via@:
-- both values are converted to their generic representation with 'from' and
-- compared with '(==~~)'.
instance (Generic a, SEq' (Rep a)) => SEq (Default a) where
  Default l ==~ Default r = from l ==~~ from r
  {-# INLINE (==~) #-}

-- | Symbolic logical operators for symbolic booleans.
--
-- >>> let t = con True :: SymBool
-- >>> let f = con False :: SymBool
-- >>> let a = "a" :: SymBool
-- >>> let b = "b" :: SymBool
-- >>> t ||~ f
-- true
-- >>> a ||~ t
-- true
-- >>> a ||~ f
-- a
-- >>> a ||~ b
-- (|| a b)
-- >>> t &&~ f
-- false
-- >>> a &&~ t
-- a
-- >>> a &&~ f
-- false
-- >>> a &&~ b
-- (&& a b)
-- >>> nots t
-- false
-- >>> nots f
-- true
-- >>> nots a
-- (! a)
-- >>> t `xors` f
-- true
-- >>> t `xors` t
-- false
-- >>> a `xors` t
-- (! a)
-- >>> a `xors` f
-- a
-- >>> a `xors` b
-- (|| (&& (! a) b) (&& a (! b)))
class LogicalOp b where
  -- | Symbolic disjunction. The default is derived from '(&&~)' and 'nots'
  -- via De Morgan's law.
  (||~) :: b -> b -> b
  a ||~ b = nots $ nots a &&~ nots b
  {-# INLINE (||~) #-}

  infixr 2 ||~

  -- | Symbolic conjunction. The default is derived from '(||~)' and 'nots'
  -- via De Morgan's law.
  (&&~) :: b -> b -> b
  a &&~ b = nots $ nots a ||~ nots b
  {-# INLINE (&&~) #-}

  infixr 3 &&~

  -- | Symbolic negation
  nots :: b -> b

  -- | Symbolic exclusive disjunction. The default holds when exactly one of
  -- the operands holds.
  xors :: b -> b -> b
  a `xors` b = (a &&~ nots b) ||~ (nots a &&~ b)
  {-# INLINE xors #-}

  -- | Symbolic implication. The default is @nots a ||~ b@.
  implies :: b -> b -> b
  a `implies` b = nots a ||~ b
  {-# INLINE implies #-}

  -- '(||~)' and '(&&~)' default to each other, so one of them must be given
  -- together with 'nots'.
  {-# MINIMAL (||~), nots | (&&~), nots #-}

-- | Concrete booleans: the operators are the ordinary 'Bool' functions.
-- 'xors' and 'implies' use the class defaults.
instance LogicalOp Bool where
  (||~) = (||)
  {-# INLINE (||~) #-}
  (&&~) = (&&)
  {-# INLINE (&&~) #-}
  nots = not
  {-# INLINE nots #-}

-- | If-then-else operator for solvable types (see "Grisette.Core#solvable"),
-- such as symbolic booleans and symbolic integers.
--
-- >>> let a = "a" :: SymBool
-- >>> let b = "b" :: SymBool
-- >>> let c = "c" :: SymBool
-- >>> ites a b c
-- (ite a b c)
class ITEOp v where
  -- | Symbolic if-then-else over values of type @v@.
  ites ::
    -- | Condition
    SymBool ->
    -- | Value for the then-branch
    v ->
    -- | Value for the else-branch
    v ->
    v

-- | Bundles the classes that a symbolic boolean type is expected to
-- implement. The class has no methods of its own.
class
  ( SimpleMergeable b,
    SEq b,
    Eq b,
    LogicalOp b,
    Solvable Bool b,
    ITEOp b
  ) =>
  SymBoolOp b

-- Generates an 'SEq' instance for a concrete type: the two values are
-- compared with '==' and the resulting 'Bool' is lifted with 'con'.
#define CONCRETE_SEQ(type) \
instance SEq type where \
  l ==~ r = con $ l == r; \
  {-# INLINE (==~) #-}

-- Same scheme for types indexed by a type-level natural @n@; the generated
-- instance requires @KnownNat n@ and @1 <= n@.
#define CONCRETE_SEQ_BV(type) \
instance (KnownNat n, 1 <= n) => SEq (type n) where \
  l ==~ r = con $ l == r; \
  {-# INLINE (==~) #-}

-- Instances for the concrete base types.
-- NOTE(review): the always-true conditional presumably exists so that source
-- formatters leave the macro invocations untouched; TODO confirm.
#if 1
CONCRETE_SEQ(Bool)
CONCRETE_SEQ(Integer)
CONCRETE_SEQ(Char)
CONCRETE_SEQ(Int)
CONCRETE_SEQ(Int8)
CONCRETE_SEQ(Int16)
CONCRETE_SEQ(Int32)
CONCRETE_SEQ(Int64)
CONCRETE_SEQ(Word)
CONCRETE_SEQ(Word8)
CONCRETE_SEQ(Word16)
CONCRETE_SEQ(Word32)
CONCRETE_SEQ(Word64)
CONCRETE_SEQ(B.ByteString)
CONCRETE_SEQ_BV(WordN)
CONCRETE_SEQ_BV(IntN)
CONCRETE_SEQ(SomeWordN)
CONCRETE_SEQ(SomeIntN)
#endif

-- Lists: equality is obtained from the generic 'Default' implementation.
deriving via
  (Default [a])
  instance
    SEq a => SEq [a]

-- Maybe: equality is obtained from the generic 'Default' implementation.
deriving via
  (Default (Maybe a))
  instance
    SEq a => SEq (Maybe a)

-- Either: equality is obtained from the generic 'Default' implementation.
deriving via
  (Default (Either e a))
  instance
    (SEq e, SEq a) => SEq (Either e a)

-- ExceptT: compared through the wrapped @m (Either e a)@ computation.
instance (SEq (m (Either e a))) => SEq (ExceptT e m a) where
  (ExceptT a) ==~ (ExceptT b) = a ==~ b
  {-# INLINE (==~) #-}

-- MaybeT: compared through the wrapped @m (Maybe a)@ computation.
instance (SEq (m (Maybe a))) => SEq (MaybeT m a) where
  (MaybeT a) ==~ (MaybeT b) = a ==~ b
  {-# INLINE (==~) #-}

-- (): the unit type has a single value, so the comparison is constantly true.
instance SEq () where
  _ ==~ _ = con True
  {-# INLINE (==~) #-}

-- (,): equality is obtained from the generic 'Default' implementation.
deriving via
  (Default (a, b))
  instance
    (SEq a, SEq b) => SEq (a, b)

-- (,,): equality is obtained from the generic 'Default' implementation.
deriving via
  (Default (a, b, c))
  instance
    (SEq a, SEq b, SEq c) => SEq (a, b, c)

-- (,,,): equality is obtained from the generic 'Default' implementation.
deriving via
  (Default (a, b, c, d))
  instance
    ( SEq a,
      SEq b,
      SEq c,
      SEq d
    ) =>
    SEq (a, b, c, d)

-- (,,,,): equality is obtained from the generic 'Default' implementation.
deriving via
  (Default (a, b, c, d, e))
  instance
    ( SEq a,
      SEq b,
      SEq c,
      SEq d,
      SEq e
    ) =>
    SEq (a, b, c, d, e)

-- (,,,,,): equality is obtained from the generic 'Default' implementation.
deriving via
  (Default (a, b, c, d, e, f))
  instance
    ( SEq a,
      SEq b,
      SEq c,
      SEq d,
      SEq e,
      SEq f
    ) =>
    SEq (a, b, c, d, e, f)

-- (,,,,,,): equality is obtained from the generic 'Default' implementation.
deriving via
  (Default (a, b, c, d, e, f, g))
  instance
    ( SEq a,
      SEq b,
      SEq c,
      SEq d,
      SEq e,
      SEq f,
      SEq g
    ) =>
    SEq (a, b, c, d, e, f, g)

-- (,,,,,,,): equality is obtained from the generic 'Default' implementation.
deriving via
  (Default (a, b, c, d, e, f, g, h))
  instance
    ( SEq a,
      SEq b,
      SEq c,
      SEq d,
      SEq e,
      SEq f,
      SEq g,
      SEq h
    ) =>
    SEq (a, b, c, d, e, f, g, h)

-- Sum (from Data.Functor.Sum): equality is obtained from the generic
-- 'Default' implementation.
deriving via
  (Default (Sum f g a))
  instance
    ( SEq (f a),
      SEq (g a)
    ) =>
    SEq (Sum f g a)

-- Writer (lazy): compared through the wrapped @m (a, s)@ computation.
instance (SEq (m (a, s))) => SEq (WriterLazy.WriterT s m a) where
  (WriterLazy.WriterT l) ==~ (WriterLazy.WriterT r) = l ==~ r
  {-# INLINE (==~) #-}

-- Writer (strict): compared through the wrapped @m (a, s)@ computation.
instance (SEq (m (a, s))) => SEq (WriterStrict.WriterT s m a) where
  (WriterStrict.WriterT l) ==~ (WriterStrict.WriterT r) = l ==~ r
  {-# INLINE (==~) #-}

-- Identity: compared through the wrapped value.
instance (SEq a) => SEq (Identity a) where
  (Identity l) ==~ (Identity r) = l ==~ r
  {-# INLINE (==~) #-}

-- IdentityT: compared through the wrapped @m a@ computation.
instance (SEq (m a)) => SEq (IdentityT m a) where
  (IdentityT l) ==~ (IdentityT r) = l ==~ r
  {-# INLINE (==~) #-}

-- Generates an 'ITEOp' instance for a symbolic type whose data constructor
-- has the same name as the type. The condition and both branches are
-- unwrapped, combined with 'pevalITETerm', and the result is wrapped again.
#define ITEOP_SIMPLE(type) \
instance ITEOp type where \
  ites (SymBool c) (type t) (type f) = type $ pevalITETerm c t f; \
  {-# INLINE ites #-}

-- Same scheme for symbolic types indexed by a type-level natural @n@; the
-- generated instance requires @KnownNat n@ and @1 <= n@.
#define ITEOP_BV(type) \
instance (KnownNat n, 1 <= n) => ITEOp (type n) where \
  ites (SymBool c) (type t) (type f) = type $ pevalITETerm c t f; \
  {-# INLINE ites #-}

-- Generates an 'ITEOp' instance that delegates to the helper given as the
-- second macro argument, passing it the partially applied 'ites' and the
-- string "ites".
-- NOTE(review): the helper is defined elsewhere; presumably it lifts the
-- operation to the existentially wrapped values and uses the string for
-- error reporting. TODO confirm.
#define ITEOP_BV_SOME(symtype, bf) \
instance ITEOp symtype where \
  ites c = bf (ites c) "ites"; \
  {-# INLINE ites #-}

-- Generates an 'ITEOp' instance for a symbolic function type written with an
-- infix type operator (first macro argument) and wrapped by a data
-- constructor (second macro argument); the wrapped terms are combined with
-- 'pevalITETerm'.
#define ITEOP_FUN(op, cons) \
instance (SupportedPrim ca, SupportedPrim cb, LinkedRep ca sa, LinkedRep cb sb) => ITEOp (sa op sb) where \
  ites (SymBool c) (cons t) (cons f) = cons $ pevalITETerm c t f; \
  {-# INLINE ites #-}

-- Instances for the symbolic primitive types.
-- NOTE(review): the always-true conditional presumably exists so that source
-- formatters leave the macro invocations untouched; TODO confirm.
#if 1
ITEOP_SIMPLE(SymBool)
ITEOP_SIMPLE(SymInteger)
ITEOP_BV(SymIntN)
ITEOP_BV(SymWordN)
ITEOP_BV_SOME(SomeSymIntN, binSomeSymIntNR1)
ITEOP_BV_SOME(SomeSymWordN, binSomeSymWordNR1)
ITEOP_FUN(=~>, SymTabularFun)
ITEOP_FUN(-~>, SymGeneralFun)
#endif

-- | Symbolic booleans: every operator unwraps the underlying boolean terms,
-- applies the corresponding @peval*Term@ function, and wraps the result.
-- All five methods are given explicitly, so none of the class defaults
-- is used.
instance LogicalOp SymBool where
  (SymBool l) ||~ (SymBool r) = SymBool $ pevalOrTerm l r
  (SymBool l) &&~ (SymBool r) = SymBool $ pevalAndTerm l r
  nots (SymBool v) = SymBool $ pevalNotTerm v
  (SymBool l) `xors` (SymBool r) = SymBool $ pevalXorTerm l r
  (SymBool l) `implies` (SymBool r) = SymBool $ pevalImplyTerm l r