{-|
Module      : What4.SemiRing
Description : Definitions related to semiring structures over base types.
Copyright   : (c) Galois Inc, 2019-2020
License     : BSD3
Maintainer  : rdockins@galois.com

The algebraic assumptions we make about our semirings are that:

* addition is commutative and associative, with a unit called zero,
* multiplication is commutative and associative, with a unit called one,
* one and zero are distinct values,
* multiplication distributes through addition, and
* multiplication by zero gives zero.

Note that we do not assume the existence of additive inverses (hence,
semirings), but we do assume commutativity of multiplication.

Note, moreover, that bitvectors can be equipped with two different
semirings (the usual arithmetic one and the XOR/AND boolean ring imposed
by the boolean algebra structure), which occasionally requires some care.

In addition, some semirings are "ordered" semirings.  These are equipped
with a total ordering relation such that addition is both order-preserving
and order-reflecting; that is, @x <= y@ iff @x + z <= y + z@.
Moreover ordered semirings satisfy: @0 <= x@ and @0 <= y@ implies @0 <= x*y@.
-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module What4.SemiRing
  ( -- * Semiring datakinds
    type SemiRing
  , type SemiRingInteger
  , type SemiRingReal
  , type SemiRingBV
  , type BVFlavor
  , type BVBits
  , type BVArith

    -- * Semiring representations
  , SemiRingRepr(..)
  , OrderedSemiRingRepr(..)
  , BVFlavorRepr(..)
  , SemiRingBase
  , semiRingBase
  , orderedSemiRing

    -- * Semiring coefficients
  , Coefficient
  , zero
  , one
  , add
  , mul
  , eq
  , le
  , lt
  , sr_compare
  , sr_hashWithSalt

    -- * Semiring product occurrences
  , Occurrence
  , occ_add
  , occ_one
  , occ_eq
  , occ_hashWithSalt
  , occ_compare
  , occ_count
  ) where

import GHC.TypeNats (Nat)
import qualified Data.BitVector.Sized as BV
import Data.Kind
import Data.Hashable
import Data.Parameterized.Classes
import Data.Parameterized.TH.GADT
import Numeric.Natural (Natural)

import What4.BaseTypes

-- | Data-kind selecting which of the two bitvector semirings is meant.
data BVFlavor
  = BVArith
    -- ^ The ordinary arithmetic semiring: addition and multiplication.
  | BVBits
    -- ^ The "bits" semiring: bitwise xor and bitwise and.

-- | Data-kind enumerating the semirings that What4 supports.
data SemiRing
  = SemiRingInteger
    -- ^ The integers.
  | SemiRingReal
    -- ^ The reals.
  | SemiRingBV BVFlavor Nat
    -- ^ Bitvectors, indexed by semiring flavor and bit width.

-- Synonyms for the promoted 'BVFlavor' constructors, so that the rest of
-- the module can name them at the type level without the promotion tick.
type BVArith = 'BVArith    -- ^ @:: 'BVFlavor'@
type BVBits  = 'BVBits     -- ^ @:: 'BVFlavor'@

-- Synonyms for the promoted 'SemiRing' constructors, for the same reason.
type SemiRingInteger = 'SemiRingInteger   -- ^ @:: 'SemiRing'@
type SemiRingReal = 'SemiRingReal         -- ^ @:: 'SemiRing'@
type SemiRingBV = 'SemiRingBV             -- ^ @:: 'BVFlavor' -> 'Nat' -> 'SemiRing'@

-- | Run-time representative of a 'BVFlavor' index; matching on a
--   constructor reveals which flavor @fv@ is.
data BVFlavorRepr (fv :: BVFlavor) where
  -- | Witness for the arithmetic flavor.
  BVArithRepr :: BVFlavorRepr BVArith
  -- | Witness for the bitwise (xor/and) flavor.
  BVBitsRepr :: BVFlavorRepr BVBits

-- | Run-time representative of a 'SemiRing' index; matching on a
--   constructor reveals which semiring @sr@ is.
data SemiRingRepr (sr :: SemiRing) where
  -- | The semiring of integers.
  SemiRingIntegerRepr :: SemiRingRepr SemiRingInteger
  -- | The semiring of reals.
  SemiRingRealRepr :: SemiRingRepr SemiRingReal
  -- | A bitvector semiring of the given flavor and width; the width
  --   must be at least 1.
  SemiRingBVRepr ::
    (1 <= w) =>
    !(BVFlavorRepr fv) ->
    !(NatRepr w) ->
    SemiRingRepr (SemiRingBV fv w)

-- | Representatives for the semirings that carry an order-respecting
--   total order.  Only the integer and real semirings qualify; there is
--   no constructor for bitvector semirings.
data OrderedSemiRingRepr (sr :: SemiRing) where
  -- | The ordered semiring of integers.
  OrderedSemiRingIntegerRepr :: OrderedSemiRingRepr SemiRingInteger
  -- | The ordered semiring of reals.
  OrderedSemiRingRealRepr :: OrderedSemiRingRepr SemiRingReal

-- | Compute the base type of the given semiring.  Both bitvector
--   flavors map to the bitvector base type of the same width; the
--   flavor itself is ignored.
semiRingBase :: SemiRingRepr sr -> BaseTypeRepr (SemiRingBase sr)
semiRingBase SemiRingIntegerRepr    = BaseIntegerRepr
semiRingBase SemiRingRealRepr       = BaseRealRepr
semiRingBase (SemiRingBVRepr _fv w) = BaseBVRepr w

-- | Compute the semiring corresponding to the given ordered semiring,
--   forgetting that it is ordered.
orderedSemiRing :: OrderedSemiRingRepr sr -> SemiRingRepr sr
orderedSemiRing OrderedSemiRingIntegerRepr = SemiRingIntegerRepr
orderedSemiRing OrderedSemiRingRealRepr    = SemiRingRealRepr

-- | The base type whose values a semiring ranges over.  Bitvector
--   semirings of either flavor share the bitvector base type of the
--   same width.
type family SemiRingBase (sr :: SemiRing) :: BaseType where
  SemiRingBase SemiRingInteger = BaseIntegerType
  SemiRingBase SemiRingReal = BaseRealType
  SemiRingBase (SemiRingBV fv w) = BaseBVType w

-- | The Haskell type of constant values (coefficients) in each
--   semiring.  The bitvector flavor does not affect the representation.
type family Coefficient (sr :: SemiRing) :: Type where
  Coefficient SemiRingInteger = Integer
  Coefficient SemiRingReal = Rational
  Coefficient (SemiRingBV fv w) = BV.BV w

-- | How many times a term occurs in a product.  For the integer, real
--   and arithmetic-bitvector semirings this is a natural number (the
--   exponent).  For the bitwise bitvector semiring it is the unit type:
--   multiplication there is idempotent, so repeated occurrences
--   carry no extra information.
type family Occurrence (sr :: SemiRing) :: Type where
  Occurrence SemiRingInteger = Natural
  Occurrence SemiRingReal = Natural
  Occurrence (SemiRingBV BVArith w) = Natural
  Occurrence (SemiRingBV BVBits w) = ()

-- | Compare two coefficients using the 'Ord' instance of the underlying
--   coefficient type.  Unlike 'le' and 'lt', this is defined for every
--   semiring, including the bitvector ones that have no
--   'OrderedSemiRingRepr'.
sr_compare :: SemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Ordering
sr_compare SemiRingIntegerRepr  = compare
sr_compare SemiRingRealRepr     = compare
sr_compare (SemiRingBVRepr _ _) = compare

-- | Hash a coefficient with the given salt, using the 'Hashable'
--   instance of the underlying coefficient type.
sr_hashWithSalt :: SemiRingRepr sr -> Int -> Coefficient sr -> Int
sr_hashWithSalt SemiRingIntegerRepr  = hashWithSalt
sr_hashWithSalt SemiRingRealRepr     = hashWithSalt
sr_hashWithSalt (SemiRingBVRepr _ _) = hashWithSalt

-- | The occurrence value for a term that appears once in a product:
--   @1@ where occurrences are counted, @()@ for the bitwise semiring.
occ_one :: SemiRingRepr sr -> Occurrence sr
occ_one SemiRingIntegerRepr = 1
occ_one SemiRingRealRepr    = 1
occ_one (SemiRingBVRepr BVArithRepr _) = 1
occ_one (SemiRingBVRepr BVBitsRepr _)  = ()

-- | Combine the occurrence values of the same term from two products
--   being multiplied.  Counted occurrences are added; in the bitwise
--   semiring the result is always @()@.
occ_add :: SemiRingRepr sr -> Occurrence sr -> Occurrence sr -> Occurrence sr
occ_add SemiRingIntegerRepr = (+)
occ_add SemiRingRealRepr    = (+)
occ_add (SemiRingBVRepr BVArithRepr _) = (+)
occ_add (SemiRingBVRepr BVBitsRepr _)  = \_ _ -> ()

-- | View an occurrence value as a natural number.  Counted occurrences
--   are returned unchanged; in the bitwise semiring every occurrence
--   counts as exactly 1.
occ_count :: SemiRingRepr sr -> Occurrence sr -> Natural
occ_count SemiRingIntegerRepr = id
occ_count SemiRingRealRepr    = id
occ_count (SemiRingBVRepr BVArithRepr _) = id
occ_count (SemiRingBVRepr BVBitsRepr _)  = \_ -> 1

-- | Equality of occurrence values.  In the bitwise semiring the only
--   occurrence value is @()@, so the answer is always 'True'.
occ_eq :: SemiRingRepr sr -> Occurrence sr -> Occurrence sr -> Bool
occ_eq SemiRingIntegerRepr = (==)
occ_eq SemiRingRealRepr    = (==)
occ_eq (SemiRingBVRepr BVArithRepr _) = (==)
occ_eq (SemiRingBVRepr BVBitsRepr _)  = \_ _ -> True

-- | Hash an occurrence value with the given salt.  The bitvector
--   flavors need separate equations because they resolve to different
--   'Hashable' instances ('Natural' versus @()@).
occ_hashWithSalt :: SemiRingRepr sr -> Int -> Occurrence sr -> Int
occ_hashWithSalt SemiRingIntegerRepr = hashWithSalt
occ_hashWithSalt SemiRingRealRepr    = hashWithSalt
occ_hashWithSalt (SemiRingBVRepr BVArithRepr _) = hashWithSalt
occ_hashWithSalt (SemiRingBVRepr BVBitsRepr _)  = hashWithSalt

-- | Compare two occurrence values using the 'Ord' instance of the
--   occurrence type.  The bitvector flavors need separate equations
--   because they resolve to different instances ('Natural' versus @()@).
occ_compare :: SemiRingRepr sr -> Occurrence sr -> Occurrence sr -> Ordering
occ_compare SemiRingIntegerRepr = compare
occ_compare SemiRingRealRepr    = compare
occ_compare (SemiRingBVRepr BVArithRepr _) = compare
occ_compare (SemiRingBVRepr BVBitsRepr _)  = compare

-- | The additive unit of the semiring.  For both bitvector flavors this
--   is the zero bitvector of the given width (it is the unit of
--   arithmetic addition and of xor alike).
zero :: SemiRingRepr sr -> Coefficient sr
zero SemiRingIntegerRepr = 0 :: Integer
zero SemiRingRealRepr    = 0 :: Rational
zero (SemiRingBVRepr BVArithRepr w) = BV.zero w
zero (SemiRingBVRepr BVBitsRepr w)  = BV.zero w

-- | The multiplicative unit of the semiring.  For arithmetic bitvectors
--   this is the bitvector with value 1; for the bitwise semiring, whose
--   multiplication is bitwise and, it is 'BV.maxUnsigned' of the width.
one :: SemiRingRepr sr -> Coefficient sr
one SemiRingIntegerRepr = 1 :: Integer
one SemiRingRealRepr    = 1 :: Rational
one (SemiRingBVRepr BVArithRepr w) = BV.mkBV w 1
one (SemiRingBVRepr BVBitsRepr w)  = BV.maxUnsigned w

-- | Semiring addition on coefficients: ordinary addition for integers
--   and reals, 'BV.add' at the given width for arithmetic bitvectors,
--   and bitwise xor for the bitwise semiring.
add :: SemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Coefficient sr
add SemiRingIntegerRepr = (+)
add SemiRingRealRepr    = (+)
add (SemiRingBVRepr BVArithRepr w) = BV.add w
add (SemiRingBVRepr BVBitsRepr _)  = BV.xor

-- | Semiring multiplication on coefficients: ordinary multiplication
--   for integers and reals, 'BV.mul' at the given width for arithmetic
--   bitvectors, and bitwise and for the bitwise semiring.
mul :: SemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Coefficient sr
mul SemiRingIntegerRepr = (*)
mul SemiRingRealRepr    = (*)
mul (SemiRingBVRepr BVArithRepr w) = BV.mul w
mul (SemiRingBVRepr BVBitsRepr _)  = BV.and

-- | Equality of coefficients, using the 'Eq' instance of the underlying
--   coefficient type.
eq :: SemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Bool
eq SemiRingIntegerRepr  = (==)
eq SemiRingRealRepr     = (==)
eq (SemiRingBVRepr _ _) = (==)

-- | Non-strict ordering test (@<=@) on the coefficients of an ordered
--   semiring.
le :: OrderedSemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Bool
le OrderedSemiRingIntegerRepr = (<=)
le OrderedSemiRingRealRepr    = (<=)

-- | Strict ordering test (@<@) on the coefficients of an ordered
--   semiring.
lt :: OrderedSemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Bool
lt OrderedSemiRingIntegerRepr = (<)
lt OrderedSemiRingRealRepr    = (<)

-- Empty top-level Template Haskell splice: it generates no declarations.
-- NOTE(review): presumably present to end the preceding declaration
-- group so that the splices below can reify the types declared above;
-- confirm against the GHC Template Haskell staging rules before moving it.
$(return [])

-- | Type-level equality test, generated structurally by Template Haskell.
instance TestEquality BVFlavorRepr where
  testEquality = $(structuralTypeEquality [t|BVFlavorRepr|] [])
-- | Equality at a fixed index, defined as success of 'testEquality'.
instance Eq (BVFlavorRepr fv) where
  x == y = isJust (testEquality x y)

-- | Type-level equality test, generated structurally by Template Haskell.
instance TestEquality OrderedSemiRingRepr where
  testEquality = $(structuralTypeEquality [t|OrderedSemiRingRepr|] [])
-- | Equality at a fixed index, defined as success of 'testEquality'.
instance Eq (OrderedSemiRingRepr sr) where
  x == y = isJust (testEquality x y)

-- | Type-level equality test, generated structurally by Template
--   Haskell.  The 'NatRepr' and 'BVFlavorRepr' fields of the bitvector
--   constructor are compared with their own 'testEquality'.
instance TestEquality SemiRingRepr where
  testEquality =
    $(structuralTypeEquality [t|SemiRingRepr|]
      [ (ConType [t|NatRepr|] `TypeApp` AnyType, [|testEquality|])
      , (ConType [t|BVFlavorRepr|] `TypeApp` AnyType, [|testEquality|])
      ])
-- | Equality at a fixed index, defined as success of 'testEquality'.
instance Eq (SemiRingRepr sr) where
  x == y = isJust (testEquality x y)

-- | Heterogeneous ordering, generated structurally by Template Haskell.
instance OrdF BVFlavorRepr where
  compareF = $(structuralTypeOrd [t|BVFlavorRepr|] [])

-- | Heterogeneous ordering, generated structurally by Template Haskell.
instance OrdF OrderedSemiRingRepr where
  compareF = $(structuralTypeOrd [t|OrderedSemiRingRepr|] [])

-- | Heterogeneous ordering, generated structurally by Template Haskell.
--   The 'NatRepr' and 'BVFlavorRepr' fields of the bitvector constructor
--   are compared with their own 'compareF'.
instance OrdF SemiRingRepr where
  compareF =
    $(structuralTypeOrd [t|SemiRingRepr|]
      [ (ConType [t|NatRepr|] `TypeApp` AnyType, [|compareF|])
      , (ConType [t|BVFlavorRepr|] `TypeApp` AnyType, [|compareF|])
      ])

-- | Index-polymorphic hashing, generated structurally by Template Haskell.
instance HashableF BVFlavorRepr where
  hashWithSaltF = $(structuralHashWithSalt [t|BVFlavorRepr|] [])
-- | Hashing at a fixed index; delegates to the 'HashableF' instance.
instance Hashable (BVFlavorRepr fv) where
  hashWithSalt = hashWithSaltF

-- | Index-polymorphic hashing, generated structurally by Template Haskell.
instance HashableF OrderedSemiRingRepr where
  hashWithSaltF = $(structuralHashWithSalt [t|OrderedSemiRingRepr|] [])
-- | Hashing at a fixed index; delegates to the 'HashableF' instance.
instance Hashable (OrderedSemiRingRepr sr) where
  hashWithSalt = hashWithSaltF

-- | Index-polymorphic hashing, generated structurally by Template
--   Haskell.  No per-field overrides are supplied here, unlike the
--   'TestEquality' and 'OrdF' instances for this type.
instance HashableF SemiRingRepr where
  hashWithSaltF = $(structuralHashWithSalt [t|SemiRingRepr|] [])
-- | Hashing at a fixed index; delegates to the 'HashableF' instance.
instance Hashable (SemiRingRepr sr) where
  hashWithSalt = hashWithSaltF