{-|
Module      : What4.SemiRing
Description : Definitions related to semiring structures over base types.
Copyright   : (c) Galois Inc, 2019-2020
License     : BSD3
Maintainer  : rdockins@galois.com

The algebraic assumptions we make about our semirings are that:

* addition is commutative and associative, with a unit called zero,
* multiplication is commutative and associative, with a unit called one,
* one and zero are distinct values,
* multiplication distributes through addition, and
* multiplication by zero gives zero.

Note that we do not assume the existence of additive inverses (hence,
semirings), but we do assume commutativity of multiplication.

Note, moreover, that bitvectors can be equipped with two different
semirings (the usual arithmetic one and the XOR/AND boolean ring imposed
by the boolean algebra structure), which occasionally requires some care.

In addition, some semirings are "ordered" semirings.  These are equipped
with a total ordering relation such that addition is both order-preserving
and order-reflecting; that is, @x <= y@ iff @x + z <= y + z@.
Moreover ordered semirings satisfy: @0 <= x@ and @0 <= y@ implies @0 <= x*y@.
-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module What4.SemiRing
  ( -- * Semiring datakinds
    type SemiRing
  , type SemiRingInteger
  , type SemiRingReal
  , type SemiRingBV
  , type BVFlavor
  , type BVBits
  , type BVArith

    -- * Semiring representations
  , SemiRingRepr(..)
  , OrderedSemiRingRepr(..)
  , BVFlavorRepr(..)
  , SemiRingBase
  , semiRingBase
  , orderedSemiRing

    -- * Semiring coefficients
  , Coefficient
  , zero
  , one
  , add
  , mul
  , eq
  , le
  , lt
  , sr_compare
  , sr_hashWithSalt

    -- * Semiring product occurrences
  , Occurrence
  , occ_add
  , occ_one
  , occ_eq
  , occ_hashWithSalt
  , occ_compare
  , occ_count
  ) where

import GHC.TypeNats
import qualified Data.BitVector.Sized as BV
import Data.Kind
import Data.Hashable
import Data.Parameterized.Classes
import Data.Parameterized.TH.GADT
import Numeric.Natural

import What4.BaseTypes

-- | Data-kind indicating the two flavors of bitvector semirings.
--   The ordinary arithmetic semiring consists of addition and multiplication,
--   and the "bits" semiring consists of bitwise xor and bitwise and.
--
--   The constructors are used promoted to the type level, as the first
--   index of 'SemiRingBV'.
data BVFlavor = BVArith | BVBits

-- | Data-kind representing the semirings What4 supports.
--
--   The constructors are used promoted to the type level as the index of
--   'SemiRingRepr'; the bitvector case is further indexed by its flavor
--   and its width.
data SemiRing
  = SemiRingInteger
  | SemiRingReal
  | SemiRingBV BVFlavor Nat

-- Unticked synonyms for the promoted constructors of 'BVFlavor', so that
-- signatures can be written without the leading tick.
type BVArith = 'BVArith    -- ^ @:: 'BVFlavor'@
type BVBits  = 'BVBits     -- ^ @:: 'BVFlavor'@

-- Unticked synonyms for the promoted constructors of 'SemiRing'.
type SemiRingInteger = 'SemiRingInteger   -- ^ @:: 'SemiRing'@
type SemiRingReal = 'SemiRingReal         -- ^ @:: 'SemiRing'@
type SemiRingBV = 'SemiRingBV             -- ^ @:: 'BVFlavor' -> 'Nat' -> 'SemiRing'@

-- | Value-level representative of a 'BVFlavor' index.  There is exactly
--   one constructor per flavor, so pattern matching on a value reveals
--   which flavor @fv@ is.
data BVFlavorRepr (fv :: BVFlavor) where
  BVArithRepr :: BVFlavorRepr BVArith
  BVBitsRepr  :: BVFlavorRepr BVBits

-- | Value-level representative of a 'SemiRing' index.  Pattern matching
--   on a value reveals which semiring @sr@ is, which in turn lets the
--   'SemiRingBase', 'Coefficient' and 'Occurrence' families reduce.
--
--   The bitvector constructor carries (strictly) the flavor and the width
--   representatives, and requires the width to be at least 1.
data SemiRingRepr (sr :: SemiRing) where
  SemiRingIntegerRepr :: SemiRingRepr SemiRingInteger
  SemiRingRealRepr    :: SemiRingRepr SemiRingReal
  SemiRingBVRepr      :: (1 <= w) => !(BVFlavorRepr fv) -> !(NatRepr w) -> SemiRingRepr (SemiRingBV fv w)

-- | The subset of semirings that are equipped with an appropriate (order-respecting) total order.
--   Only the integer and real semirings have a constructor here; the
--   bitvector semirings do not.
data OrderedSemiRingRepr (sr :: SemiRing) where
  OrderedSemiRingIntegerRepr :: OrderedSemiRingRepr SemiRingInteger
  OrderedSemiRingRealRepr    :: OrderedSemiRingRepr SemiRingReal

-- | Compute the base type of the given semiring.
--
--   Both bitvector flavors map to the bitvector base type of the same
--   width; the flavor is ignored.
semiRingBase :: SemiRingRepr sr -> BaseTypeRepr (SemiRingBase sr)
semiRingBase SemiRingIntegerRepr    = BaseIntegerRepr
semiRingBase SemiRingRealRepr       = BaseRealRepr
semiRingBase (SemiRingBVRepr _fv w) = BaseBVRepr w

-- | Compute the semiring corresponding to the given ordered semiring.
--   This forgets the ordering and keeps only the semiring representative.
orderedSemiRing :: OrderedSemiRingRepr sr -> SemiRingRepr sr
orderedSemiRing OrderedSemiRingIntegerRepr = SemiRingIntegerRepr
orderedSemiRing OrderedSemiRingRealRepr    = SemiRingRealRepr

-- | The What4 base type whose values are the elements of the given
--   semiring.  Both bitvector flavors of width @w@ map to @BaseBVType w@.
type family SemiRingBase (sr :: SemiRing) :: BaseType where
  SemiRingBase SemiRingInteger   = BaseIntegerType
  SemiRingBase SemiRingReal      = BaseRealType
  SemiRingBase (SemiRingBV fv w) = BaseBVType w

-- | The constant values in the semiring: the Haskell type used to
--   represent concrete coefficients.  Both bitvector flavors of width
--   @w@ use @BV.BV w@.
type family Coefficient (sr :: SemiRing) :: Type where
  Coefficient SemiRingInteger    = Integer
  Coefficient SemiRingReal       = Rational
  Coefficient (SemiRingBV fv w)  = BV.BV w

-- | The 'Occurrence' family counts how many times a term occurs in a
--   product. For most semirings, this is just a natural number
--   representing the exponent. For the boolean ring of bitvectors
--   (the 'BVBits' flavor), however, it is unit because the
--   multiplication of that ring (bitwise and) is idempotent, so
--   repeating a factor never changes the product.
type family Occurrence (sr :: SemiRing) :: Type where
  Occurrence SemiRingInteger        = Natural
  Occurrence SemiRingReal           = Natural
  Occurrence (SemiRingBV BVArith w) = Natural
  Occurrence (SemiRingBV BVBits w)  = ()

-- | Compare two coefficients of the given semiring using the 'Ord'
--   instance of the underlying coefficient type.
--
--   Every clause has the same right-hand side; the match on the
--   representative is needed so that @Coefficient sr@ reduces to a
--   concrete type with an 'Ord' instance.
sr_compare :: SemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Ordering
sr_compare SemiRingIntegerRepr  = compare
sr_compare SemiRingRealRepr     = compare
sr_compare (SemiRingBVRepr _ _) = compare

-- | Hash a coefficient of the given semiring with the supplied salt,
--   using the 'Hashable' instance of the underlying coefficient type.
--
--   The match on the representative is needed so that @Coefficient sr@
--   reduces to a concrete type with a 'Hashable' instance.
sr_hashWithSalt :: SemiRingRepr sr -> Int -> Coefficient sr -> Int
sr_hashWithSalt SemiRingIntegerRepr  = hashWithSalt
sr_hashWithSalt SemiRingRealRepr     = hashWithSalt
sr_hashWithSalt (SemiRingBVRepr _ _) = hashWithSalt

-- | The occurrence value denoting a single occurrence of a term in a
--   product: the natural number 1 for the integer, real and arithmetic
--   bitvector semirings, and unit for the bits flavor.
occ_one :: SemiRingRepr sr -> Occurrence sr
occ_one SemiRingIntegerRepr            = 1
occ_one SemiRingRealRepr               = 1
occ_one (SemiRingBVRepr BVArithRepr _) = 1
occ_one (SemiRingBVRepr BVBitsRepr _)  = ()

-- | Combine two occurrence values for the same term.  For the semirings
--   whose occurrences are natural numbers this is addition; for the bits
--   flavor the result is always unit.
occ_add :: SemiRingRepr sr -> Occurrence sr -> Occurrence sr -> Occurrence sr
occ_add SemiRingIntegerRepr            = (+)
occ_add SemiRingRealRepr               = (+)
occ_add (SemiRingBVRepr BVArithRepr _) = (+)
occ_add (SemiRingBVRepr BVBitsRepr _)  = \_ _ -> ()

-- | Convert an occurrence value to a natural number.  Where occurrences
--   already are natural numbers this is the identity; for the bits
--   flavor, whose occurrence type is unit, the result is always 1.
occ_count :: SemiRingRepr sr -> Occurrence sr -> Natural
occ_count SemiRingIntegerRepr            = id
occ_count SemiRingRealRepr               = id
occ_count (SemiRingBVRepr BVArithRepr _) = id
occ_count (SemiRingBVRepr BVBitsRepr _)  = \_ -> 1

-- | Equality on occurrence values.  For the bits flavor the occurrence
--   type is unit, so any two occurrences are equal.
occ_eq :: SemiRingRepr sr -> Occurrence sr -> Occurrence sr -> Bool
occ_eq SemiRingIntegerRepr            = (==)
occ_eq SemiRingRealRepr               = (==)
occ_eq (SemiRingBVRepr BVArithRepr _) = (==)
occ_eq (SemiRingBVRepr BVBitsRepr _)  = \_ _ -> True

-- | Hash an occurrence value with the supplied salt, using the
--   'Hashable' instance of the underlying occurrence type.
--
--   The bitvector flavor must be matched as well as the semiring, since
--   @Occurrence sr@ only reduces once the flavor is known.
occ_hashWithSalt :: SemiRingRepr sr -> Int -> Occurrence sr -> Int
occ_hashWithSalt SemiRingIntegerRepr            = hashWithSalt
occ_hashWithSalt SemiRingRealRepr               = hashWithSalt
occ_hashWithSalt (SemiRingBVRepr BVArithRepr _) = hashWithSalt
occ_hashWithSalt (SemiRingBVRepr BVBitsRepr _)  = hashWithSalt

-- | Compare two occurrence values using the 'Ord' instance of the
--   underlying occurrence type.
--
--   The bitvector flavor must be matched as well as the semiring, since
--   @Occurrence sr@ only reduces once the flavor is known.
occ_compare :: SemiRingRepr sr -> Occurrence sr -> Occurrence sr -> Ordering
occ_compare SemiRingIntegerRepr            = compare
occ_compare SemiRingRealRepr               = compare
occ_compare (SemiRingBVRepr BVArithRepr _) = compare
occ_compare (SemiRingBVRepr BVBitsRepr _)  = compare

-- | The additive unit of the given semiring.  For both bitvector flavors
--   this is the zero bitvector of the given width.
zero :: SemiRingRepr sr -> Coefficient sr
zero SemiRingIntegerRepr            = 0 :: Integer
zero SemiRingRealRepr               = 0 :: Rational
zero (SemiRingBVRepr BVArithRepr w) = BV.zero w
zero (SemiRingBVRepr BVBitsRepr w)  = BV.zero w

-- | The multiplicative unit of the given semiring.
--
--   The two bitvector flavors differ: the arithmetic flavor uses the
--   bitvector with value 1, whereas the bits flavor, whose multiplication
--   is bitwise and, uses @BV.maxUnsigned w@.
one :: SemiRingRepr sr -> Coefficient sr
one SemiRingIntegerRepr            = 1 :: Integer
one SemiRingRealRepr               = 1 :: Rational
one (SemiRingBVRepr BVArithRepr w) = BV.mkBV w 1
one (SemiRingBVRepr BVBitsRepr w)  = BV.maxUnsigned w

-- | Addition in the given semiring: ordinary addition for integers and
--   reals, @BV.add@ at the given width for the arithmetic bitvector
--   flavor, and bitwise xor for the bits flavor.
add :: SemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Coefficient sr
add SemiRingIntegerRepr            = (+)
add SemiRingRealRepr               = (+)
add (SemiRingBVRepr BVArithRepr w) = BV.add w
add (SemiRingBVRepr BVBitsRepr _)  = BV.xor

-- | Multiplication in the given semiring: ordinary multiplication for
--   integers and reals, @BV.mul@ at the given width for the arithmetic
--   bitvector flavor, and bitwise and for the bits flavor.
mul :: SemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Coefficient sr
mul SemiRingIntegerRepr            = (*)
mul SemiRingRealRepr               = (*)
mul (SemiRingBVRepr BVArithRepr w) = BV.mul w
mul (SemiRingBVRepr BVBitsRepr _)  = BV.and

-- | Equality on coefficients of the given semiring, using the 'Eq'
--   instance of the underlying coefficient type.
--
--   The match on the representative is needed so that @Coefficient sr@
--   reduces to a concrete type with an 'Eq' instance.
eq :: SemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Bool
eq SemiRingIntegerRepr  = (==)
eq SemiRingRealRepr     = (==)
eq (SemiRingBVRepr _ _) = (==)

-- | Non-strict ordering test (@<=@) on coefficients.  Only available for
--   ordered semirings, i.e. the integers and the reals.
le :: OrderedSemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Bool
le OrderedSemiRingIntegerRepr = (<=)
le OrderedSemiRingRealRepr    = (<=)

-- | Strict ordering test (@<@) on coefficients.  Only available for
--   ordered semirings, i.e. the integers and the reals.
lt :: OrderedSemiRingRepr sr -> Coefficient sr -> Coefficient sr -> Bool
lt OrderedSemiRingIntegerRepr = (<)
lt OrderedSemiRingRealRepr    = (<)

-- Empty Template Haskell splice that generates no declarations.  A
-- top-level splice ends the current declaration group; presumably this is
-- here so that the types declared above can be inspected by the
-- structural* splices used in the instances below.
$(return [])

-- | Type-level equality test for 'BVFlavorRepr', generated structurally
--   by Template Haskell.
instance TestEquality BVFlavorRepr where
  testEquality = $(structuralTypeEquality [t|BVFlavorRepr|] [])

-- | Type-level equality test for 'OrderedSemiRingRepr', generated
--   structurally by Template Haskell.
instance TestEquality OrderedSemiRingRepr where
  testEquality = $(structuralTypeEquality [t|OrderedSemiRingRepr|] [])

-- | Type-level equality test for 'SemiRingRepr', generated structurally
--   by Template Haskell.  The 'NatRepr' and 'BVFlavorRepr' fields of the
--   bitvector constructor are compared with their own 'testEquality'.
instance TestEquality SemiRingRepr where
  testEquality =
    $(structuralTypeEquality [t|SemiRingRepr|]
      [ (ConType [t|NatRepr|] `TypeApp` AnyType, [|testEquality|])
      , (ConType [t|BVFlavorRepr|] `TypeApp` AnyType, [|testEquality|])
      ])

-- | Heterogeneous ordering for 'BVFlavorRepr', generated structurally by
--   Template Haskell.
instance OrdF BVFlavorRepr where
  compareF = $(structuralTypeOrd [t|BVFlavorRepr|] [])

-- | Heterogeneous ordering for 'OrderedSemiRingRepr', generated
--   structurally by Template Haskell.
instance OrdF OrderedSemiRingRepr where
  compareF = $(structuralTypeOrd [t|OrderedSemiRingRepr|] [])

-- | Heterogeneous ordering for 'SemiRingRepr', generated structurally by
--   Template Haskell.  The 'NatRepr' and 'BVFlavorRepr' fields of the
--   bitvector constructor are compared with their own 'compareF'.
instance OrdF SemiRingRepr where
  compareF =
    $(structuralTypeOrd [t|SemiRingRepr|]
      [ (ConType [t|NatRepr|] `TypeApp` AnyType, [|compareF|])
      , (ConType [t|BVFlavorRepr|] `TypeApp` AnyType, [|compareF|])
      ])

-- | Index-polymorphic hashing for 'BVFlavorRepr', generated structurally
--   by Template Haskell.
instance HashableF BVFlavorRepr where
  hashWithSaltF = $(structuralHashWithSalt [t|BVFlavorRepr|] [])
-- | Ordinary 'Hashable' instance, delegating to the 'HashableF' instance.
instance Hashable (BVFlavorRepr fv) where
  hashWithSalt = hashWithSaltF

-- | Index-polymorphic hashing for 'OrderedSemiRingRepr', generated
--   structurally by Template Haskell.
instance HashableF OrderedSemiRingRepr where
  hashWithSaltF = $(structuralHashWithSalt [t|OrderedSemiRingRepr|] [])
-- | Ordinary 'Hashable' instance, delegating to the 'HashableF' instance.
instance Hashable (OrderedSemiRingRepr sr) where
  hashWithSalt = hashWithSaltF

-- | Index-polymorphic hashing for 'SemiRingRepr', generated structurally
--   by Template Haskell.
instance HashableF SemiRingRepr where
  hashWithSaltF = $(structuralHashWithSalt [t|SemiRingRepr|] [])
-- | Ordinary 'Hashable' instance, delegating to the 'HashableF' instance.
instance Hashable (SemiRingRepr sr) where
  hashWithSalt = hashWithSaltF