{-|
Module      : What4.Expr.WeightedSum
Description : Representations for weighted sums and products in semirings
Copyright   : (c) Galois Inc, 2015-2020
License     : BSD3
Maintainer  : jhendrix@galois.com

Declares a weighted sum type used for representing sums over variables and an offset
in one of the supported semirings.  This module also implements a representation of
semiring products.
-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wwarn #-}
module What4.Expr.WeightedSum
  ( -- * Utilities
    Tm
    -- * Weighted sums
  , WeightedSum
  , sumRepr
  , sumOffset
  , sumAbsValue
  , constant
  , var
  , scaledVar
  , asConstant
  , asVar
  , asWeightedVar
  , asAffineVar
  , isZero
  , traverseVars
  , traverseCoeffs
  , add
  , addVar
  , addVars
  , addConstant
  , scale
  , eval
  , evalM
  , extractCommon
  , fromTerms
  , transformSum
  , reduceIntSumMod

    -- * Ring products
  , SemiRingProduct
  , traverseProdVars
  , nullProd
  , asProdVar
  , prodRepr
  , prodVar
  , prodAbsValue
  , prodMul
  , prodEval
  , prodEvalM
  , prodContains
  ) where

import           Control.Lens
import           Control.Monad (unless)
import qualified Data.BitVector.Sized as BV
import           Data.Hashable
import           Data.Kind
import           Data.List (foldl')
import           Data.Maybe
import           Data.Parameterized.Classes

import           What4.BaseTypes
import qualified What4.SemiRing as SR
import           What4.Utils.AnnotatedMap (AnnotatedMap)
import qualified What4.Utils.AnnotatedMap as AM
import qualified What4.Utils.AbstractDomains as AD
import qualified What4.Utils.BVDomain.Arith as A
import qualified What4.Utils.BVDomain.XOR as X
import qualified What4.Utils.BVDomain as BVD

import           What4.Utils.IncrHash

--------------------------------------------------------------------------------

-- | Abstract-domain summary of a value in the semiring @sr@.  The semiring
-- index determines which constructor is used:
--
--  * integers are tracked by an 'AD.ValueRange';
--  * reals by an 'AD.RealAbstractValue';
--  * arithmetic-flavored bitvectors by the arithmetic domain 'A.Domain';
--  * bits-flavored bitvectors by the XOR domain 'X.Domain'.
data SRAbsValue :: SR.SemiRing -> Type where
  SRAbsIntAdd
    :: !(AD.ValueRange Integer)
    -> SRAbsValue SR.SemiRingInteger
  SRAbsRealAdd
    :: !AD.RealAbstractValue
    -> SRAbsValue SR.SemiRingReal
  SRAbsBVAdd
    :: (1 <= w)
    => !(A.Domain w)
    -> SRAbsValue (SR.SemiRingBV SR.BVArith w)
  SRAbsBVXor
    :: (1 <= w)
    => !(X.Domain w)
    -> SRAbsValue (SR.SemiRingBV SR.BVBits w)

-- | Abstract addition.  Each domain's own addition operator is used:
-- 'AD.addRange' for integers, 'AD.ravAdd' for reals, 'A.add' for
-- arithmetic bitvectors, and 'X.xor' for bits-flavored bitvectors (whose
-- semiring addition is XOR).
instance Semigroup (SRAbsValue sr) where
  lhs <> rhs =
    case (lhs, rhs) of
      (SRAbsIntAdd a,  SRAbsIntAdd b)  -> SRAbsIntAdd  (AD.addRange a b)
      (SRAbsRealAdd a, SRAbsRealAdd b) -> SRAbsRealAdd (AD.ravAdd a b)
      (SRAbsBVAdd a,   SRAbsBVAdd b)   -> SRAbsBVAdd   (A.add a b)
      (SRAbsBVXor a,   SRAbsBVXor b)   -> SRAbsBVXor   (X.xor a b)


-- | Abstract multiplication.  Each domain's own multiplication operator is
-- used: 'AD.mulRange' for integers, 'AD.ravMul' for reals, 'A.mul' for
-- arithmetic bitvectors, and 'X.and' for bits-flavored bitvectors (whose
-- semiring multiplication is bitwise AND).
(.**) :: SRAbsValue sr -> SRAbsValue sr -> SRAbsValue sr
(.**) lhs rhs =
  case (lhs, rhs) of
    (SRAbsIntAdd a,  SRAbsIntAdd b)  -> SRAbsIntAdd  (AD.mulRange a b)
    (SRAbsRealAdd a, SRAbsRealAdd b) -> SRAbsRealAdd (AD.ravMul a b)
    (SRAbsBVAdd a,   SRAbsBVAdd b)   -> SRAbsBVAdd   (A.mul a b)
    (SRAbsBVXor a,   SRAbsBVXor b)   -> SRAbsBVXor   (X.and a b)

-- | Abstract value of the weighted term @c * e@ in semiring @sr@.
--
-- The abstract value of @e@ (from 'AD.getAbsValue') is scaled by the
-- coefficient using the scalar operation of the matching domain.
abstractTerm ::
  AD.HasAbsValue f =>
  SR.SemiRingRepr sr -> SR.Coefficient sr -> f (SR.SemiRingBase sr) -> SRAbsValue sr
abstractTerm sr c e =
  case sr of
    SR.SemiRingIntegerRepr ->
      SRAbsIntAdd (AD.rangeScalarMul c (AD.getAbsValue e))
    SR.SemiRingRealRepr ->
      SRAbsRealAdd (AD.ravScalarMul c (AD.getAbsValue e))
    SR.SemiRingBVRepr SR.BVArithRepr w ->
      -- A.scale expects a signed integer coefficient
      SRAbsBVAdd (A.scale (BV.asSigned w c) (BVD.asArithDomain (AD.getAbsValue e)))
    SR.SemiRingBVRepr SR.BVBitsRepr _ ->
      -- Scaling in the bits semiring is a bitwise AND with the coefficient.
      SRAbsBVXor (X.and_scalar (BV.asUnsigned c) (BVD.asXorDomain (AD.getAbsValue e)))

-- | Abstract value of the (unscaled) term @e@ in semiring @sr@.  Bitvector
-- values are projected into the arithmetic or XOR domain according to the
-- semiring's flavor.
abstractVal :: AD.HasAbsValue f => SR.SemiRingRepr sr -> f (SR.SemiRingBase sr) -> SRAbsValue sr
abstractVal sr e =
  case sr of
    SR.SemiRingIntegerRepr ->
      SRAbsIntAdd (AD.getAbsValue e)
    SR.SemiRingRealRepr ->
      SRAbsRealAdd (AD.getAbsValue e)
    SR.SemiRingBVRepr SR.BVArithRepr _ ->
      SRAbsBVAdd (BVD.asArithDomain (AD.getAbsValue e))
    SR.SemiRingBVRepr SR.BVBitsRepr _ ->
      SRAbsBVXor (BVD.asXorDomain (AD.getAbsValue e))

-- | Abstract value denoting exactly the constant @c@ in semiring @sr@.
-- Bitvector constants are passed to the domain as their unsigned value.
abstractScalar ::
  SR.SemiRingRepr sr -> SR.Coefficient sr -> SRAbsValue sr
abstractScalar sr c =
  case sr of
    SR.SemiRingIntegerRepr ->
      SRAbsIntAdd (AD.SingleRange c)
    SR.SemiRingRealRepr ->
      SRAbsRealAdd (AD.ravSingle c)
    SR.SemiRingBVRepr SR.BVArithRepr w ->
      SRAbsBVAdd (A.singleton w (BV.asUnsigned c))
    SR.SemiRingBVRepr SR.BVBitsRepr w ->
      SRAbsBVXor (X.singleton w (BV.asUnsigned c))

-- | Convert a semiring-indexed abstract value back into the generic
-- 'AD.AbstractValue' of the semiring's base type.  Arithmetic bitvector
-- domains are wrapped with 'BVD.BVDArith'; XOR domains are converted with
-- 'BVD.fromXorDomain'.
fromSRAbsValue ::
  SRAbsValue sr -> AD.AbstractValue (SR.SemiRingBase sr)
fromSRAbsValue (SRAbsIntAdd r)  = r
fromSRAbsValue (SRAbsRealAdd r) = r
fromSRAbsValue (SRAbsBVAdd d)   = BVD.BVDArith d
fromSRAbsValue (SRAbsBVXor d)   = BVD.fromXorDomain d

--------------------------------------------------------------------------------

-- | Constraints required of the term type @f@ stored in sums and products:
-- an 'OrdF' ordering (map keys are 'WrapF'-wrapped terms), 'HashableF'
-- hashing, and access to abstract values via 'AD.HasAbsValue'.
type Tm f = (OrdF f, HashableF f, AD.HasAbsValue f)

-- | A term of the semiring's base type, re-indexed by the semiring itself so
-- that it can serve as a key of the annotated sum and product maps.
newtype WrapF (f :: BaseType -> Type) (i :: SR.SemiRing)
  = WrapF (f (SR.SemiRingBase i))

-- | Wrapped terms are ordered by the 'OrdF' ordering of the underlying terms.
instance OrdF f => Ord (WrapF f i) where
  compare (WrapF lhs) (WrapF rhs) = toOrdering (compareF lhs rhs)

-- | Wrapped terms are equal exactly when 'testEquality' on the underlying
-- terms succeeds.
instance TestEquality f => Eq (WrapF f i) where
  WrapF lhs == WrapF rhs =
    case testEquality lhs rhs of
      Just _  -> True
      Nothing -> False

-- | Hash a wrapped term with the term's own 'HashableF' instance.
instance (HashableF f, TestEquality f) => Hashable (WrapF f i) where
  hashWithSalt salt (WrapF term) = hashWithSaltF salt term

-- | Apply an effectful transformation to the term inside a 'WrapF' and
-- re-wrap the result.
traverseWrap :: Functor m => (f (SR.SemiRingBase i) -> m (g (SR.SemiRingBase i))) -> WrapF f i -> m (WrapF g i)
traverseWrap k (WrapF term) = fmap WrapF (k term)

-- | The annotation type used for the annotated 'SumMap'.  It pairs an
-- incremental hash of the entries in a submap with the 'SRAbsValue'
-- summarizing them; annotations of submaps are combined through the
-- 'Semigroup' instance.
data Note sr = Note !IncrHash !(SRAbsValue sr)

-- | Combine the annotations of two submaps: the hashes are merged and the
-- abstract values are added.
instance Semigroup (Note sr) where
  (<>) (Note hashL absL) (Note hashR absR) =
    Note (hashL <> hashR) (absL <> absR)

-- | The annotation type used for the annotated 'ProdMap'.  It pairs an
-- incremental hash of the entries in a submap with the 'SRAbsValue'
-- summarizing their product.
data ProdNote sr = ProdNote !IncrHash !(SRAbsValue sr)

-- | Combine the annotations of two submaps of a product: the hashes are
-- merged and the abstract values are multiplied with '.**'.
--
-- NOTE: multiplication of abstract values is not always associative, so this
-- instance does not strictly satisfy the 'Semigroup' laws.  This is
-- acceptable because every grouping yields a sound (but perhaps not the
-- best) approximate value.
instance Semigroup (ProdNote sr) where
  (<>) (ProdNote hashL absL) (ProdNote hashR absR) =
    ProdNote (hashL <> hashR) (absL .** absR)

-- | Construct the annotation for a single map entry @c * t@.  The hash is
-- built from the term's hash salted with the coefficient, and the abstract
-- value is that of the scaled term.
mkNote ::
  (HashableF f, AD.HasAbsValue f) =>
  SR.SemiRingRepr sr -> SR.Coefficient sr -> f (SR.SemiRingBase sr) -> Note sr
mkNote sr c t =
  Note (mkIncrHash (SR.sr_hashWithSalt sr (hashF t) c)) (abstractTerm sr c t)

-- | Construct the annotation for a single product entry: the term @t@ with
-- the occurrence @occ@.
--
-- The abstract value is the term's abstract value multiplied by itself
-- @count - 1@ times (where @count@ is 'SR.occ_count' of @occ@), grouped to
-- the right.  A count of 0 or 1 yields the term's abstract value unchanged.
mkProdNote ::
  (HashableF f, AD.HasAbsValue f) =>
  SR.SemiRingRepr sr ->
  SR.Occurrence sr ->
  f (SR.SemiRingBase sr) ->
  ProdNote sr
mkProdNote sr occ t =
  ProdNote (mkIncrHash (SR.occ_hashWithSalt sr (hashF t) occ)) absPower
  where
    absT = abstractVal sr t
    occCount = toInteger (SR.occ_count sr occ)
    -- Builds absT .** (absT .** (... .** absT)); the grouping matters
    -- because abstract multiplication is not always associative.
    absPower = foldl' (\acc _ -> absT .** acc) absT [2 .. occCount]

-- | Map from each term of a weighted sum to its coefficient, annotated with
-- 'Note' summaries.
type SumMap f sr  = AnnotatedMap (WrapF f sr) (Note sr) (SR.Coefficient sr)
-- | Map from each term of a semiring product to its occurrence, annotated
-- with 'ProdNote' summaries.
type ProdMap f sr = AnnotatedMap (WrapF f sr) (ProdNote sr) (SR.Occurrence sr)

-- | Add the weighted term @c * t@ to a sum map.
--
-- If @t@ is already present, its coefficient is incremented by @c@. The
-- entry is removed when the resulting coefficient is zero. Adding a zero
-- coefficient leaves the map unchanged. As a result, this function never
-- creates an entry whose coefficient is zero, which 'unfilteredSum'
-- requires of its map argument. Previously, a new term with coefficient 0
-- was stored as an explicit zero entry.
insertSumMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  SR.Coefficient sr -> f (SR.SemiRingBase sr) -> SumMap f sr -> SumMap f sr
insertSumMap sr c t m
  | SR.eq sr (SR.zero sr) c = m
  | otherwise = AM.alter upd (WrapF t) m
  where
    upd Nothing = Just (mkNote sr c t, c)
    upd (Just (_, c0))
      | SR.eq sr (SR.zero sr) c' = Nothing
      | otherwise = Just (mkNote sr c' t, c')
      where c' = SR.add sr c0 c

-- | A sum map holding the single entry @c * t@.  No check is made that @c@
-- is nonzero.
singletonSumMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  SR.Coefficient sr -> f (SR.SemiRingBase sr) -> SumMap f sr
singletonSumMap sr c t = AM.singleton termKey termNote c
  where
    termKey  = WrapF t
    termNote = mkNote sr c t

-- | A product map holding the single term @t@ with occurrence @occ@.
singletonProdMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  SR.Occurrence sr ->
  f (SR.SemiRingBase sr) ->
  ProdMap f sr
singletonProdMap sr occ t = AM.singleton termKey termNote occ
  where
    termKey  = WrapF t
    termNote = mkProdNote sr occ t

-- | Build a sum map from (term, coefficient) pairs.  Entries are inserted
-- with 'insertSumMap', starting from the end of the list, so coefficients of
-- repeated terms are added together.
fromListSumMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  [(f (SR.SemiRingBase sr), SR.Coefficient sr)] -> SumMap f sr
fromListSumMap sr = foldr (\(t, c) acc -> insertSumMap sr c t acc) AM.empty

-- | List the (term, coefficient) entries of a sum map, in the order given
-- by 'AM.toList'.
toListSumMap :: SumMap f sr -> [(f (SR.SemiRingBase sr), SR.Coefficient sr)]
toListSumMap = map unwrapKey . AM.toList
  where
    unwrapKey (WrapF t, c) = (t, c)

-- | A weighted sum of semiring values.  Mathematically, this represents
--   an affine operation on the underlying expressions: each term of
--   '_sumMap' multiplied by its coefficient, summed, plus '_sumOffset'.
data WeightedSum (f :: BaseType -> Type) (sr :: SR.SemiRing)
  = WeightedSum
    { _sumMap    :: !(SumMap f sr)
      -- ^ Mapping from terms to their coefficients.
    , _sumOffset :: !(SR.Coefficient sr)
      -- ^ Constant addend of the sum.
    , sumRepr    :: !(SR.SemiRingRepr sr)
      -- ^ Runtime representation of the semiring for this sum.
    }

-- | A product of semiring values, recorded as a map from each term to its
--   'SR.Occurrence' in the product.
data SemiRingProduct (f :: BaseType -> Type) (sr :: SR.SemiRing)
  = SemiRingProduct
    { _prodMap  :: !(ProdMap f sr)
      -- ^ Mapping from terms to their occurrences.
    , prodRepr  :: !(SR.SemiRingRepr sr)
      -- ^ Runtime representation of the semiring for this product
    }

-- | Return the hash of the 'SumMap' part of the 'WeightedSum' ('mempty' when
-- the map has no annotation, i.e. is empty).
sumMapHash :: OrdF f => WeightedSum f sr -> IncrHash
sumMapHash x = maybe mempty (\(Note h _) -> h) (AM.annotation (_sumMap x))

-- | Return the hash of the 'ProdMap' part of a 'SemiRingProduct' ('mempty'
-- when the map has no annotation, i.e. is empty).
prodMapHash :: OrdF f => SemiRingProduct f sr -> IncrHash
prodMapHash pd = maybe mempty (\(ProdNote h _) -> h) (AM.annotation (_prodMap pd))

-- | Abstract value of a whole weighted sum.  This is the abstraction of the
-- offset, added to the map's annotated abstraction of all terms when there
-- are any.
sumAbsValue :: OrdF f => WeightedSum f sr -> AD.AbstractValue (SR.SemiRingBase sr)
sumAbsValue wsum =
  fromSRAbsValue (maybe offsetAbs withTerms (AM.annotation (_sumMap wsum)))
  where
    offsetAbs = abstractScalar (sumRepr wsum) (_sumOffset wsum)
    withTerms (Note _ termsAbs) = offsetAbs <> termsAbs

-- | Products are compared in three steps.  Differing map hashes reject
-- quickly; then the semiring representations are compared; finally the maps
-- are compared entry by entry with 'SR.occ_eq'.
instance OrdF f => TestEquality (SemiRingProduct f) where
  testEquality x y
    | prodMapHash x /= prodMapHash y = Nothing
    | otherwise =
        case testEquality (prodRepr x) (prodRepr y) of
          Just Refl
            | AM.eqBy (SR.occ_eq (prodRepr x)) (_prodMap x) (_prodMap y) -> Just Refl
          _ -> Nothing

-- | Equality of two products over the same semiring, via 'testEquality'.
instance OrdF f => Eq (SemiRingProduct f sr) where
  x == y =
    case testEquality x y of
      Just _  -> True
      Nothing -> False

-- | Weighted sums are compared in four steps.  Differing map hashes reject
-- quickly; then the semiring representations are compared; then the offsets;
-- and finally the maps entry by entry, with coefficients compared by 'SR.eq'.
instance OrdF f => TestEquality (WeightedSum f) where
  testEquality x y
    | sumMapHash x /= sumMapHash y = Nothing
    | otherwise =
        case testEquality (sumRepr x) (sumRepr y) of
          Just Refl
            | SR.eq sr (_sumOffset x) (_sumOffset y)
            , AM.eqBy (SR.eq sr) (_sumMap x) (_sumMap y) -> Just Refl
          _ -> Nothing
    where
      sr = sumRepr x

-- | Equality of two weighted sums over the same semiring, via 'testEquality'.
instance OrdF f => Eq (WeightedSum f sr) where
  x == y =
    case testEquality x y of
      Just _  -> True
      Nothing -> False


-- | Create a weighted sum directly from a map and a constant offset.
--
-- Note: no filtering is performed.  Callers should ensure that map entries
-- whose coefficient is zero have already been removed.
unfilteredSum ::
  SR.SemiRingRepr sr ->
  SumMap f sr ->
  SR.Coefficient sr ->
  WeightedSum f sr
unfilteredSum sr m c =
  WeightedSum { _sumMap = m, _sumOffset = c, sumRepr = sr }

-- | Lens onto the mapping from terms to coefficients.  Setting through this
-- lens does not filter out zero coefficients.
sumMap :: Lens' (WeightedSum f sr) (SumMap f sr)
sumMap = lens _sumMap setSumMap
  where
    setSumMap w m = w { _sumMap = m }

-- | Lens onto the constant addend of the weighted sum.
sumOffset :: Lens' (WeightedSum f sr) (SR.Coefficient sr)
sumOffset = lens _sumOffset setOffset
  where
    setOffset s v = s { _sumOffset = v }

-- | Hash a weighted sum by first salting with the offset, then mixing in
-- the precomputed hash of the term map.
instance OrdF f => Hashable (WeightedSum f sr) where
  hashWithSalt s0 w = hashWithSalt offsetSalt (sumMapHash w)
    where
      offsetSalt = SR.sr_hashWithSalt (sumRepr w) s0 (_sumOffset w)

-- | Hash a product using the precomputed hash of its term map.
instance OrdF f => Hashable (SemiRingProduct f sr) where
  hashWithSalt salt = hashWithSalt salt . prodMapHash

-- | Attempt to view a weighted sum as a bare constant.
--   Succeeds, returning the offset, exactly when the sum has no variable
--   terms.
asConstant :: WeightedSum f sr -> Maybe (SR.Coefficient sr)
asConstant ws =
  if AM.null (_sumMap ws)
    then Just (_sumOffset ws)
    else Nothing

-- | Return 'True' if a weighted sum is syntactically the constant 0: it has
--   no variable terms and its offset compares 'EQ' to the semiring zero.
isZero :: SR.SemiRingRepr sr -> WeightedSum f sr -> Bool
isZero sr ws = maybe False isZeroCoeff (asConstant ws)
  where
    isZeroCoeff c = SR.sr_compare sr (SR.zero sr) c == EQ

-- | Attempt to view a weighted sum as one scaled term plus an offset.
--   @asAffineVar w = Just (c,r,o)@ when @denotation(w) = c*r + o@.
--   Succeeds exactly when the term map holds a single entry; the offset
--   may be any value, including zero.
asAffineVar :: WeightedSum f sr -> Maybe (SR.Coefficient sr, f (SR.SemiRingBase sr), SR.Coefficient sr)
asAffineVar ws =
  case AM.toList (_sumMap ws) of
    [(WrapF term, coeff)] -> Just (coeff, term, _sumOffset ws)
    _                     -> Nothing

-- | Attempt to view a weighted sum as a single scaled term.
--   @asWeightedVar w = Just (c,r)@ when @denotation(w) = c*r@.
--   Requires exactly one term in the map and a zero offset.
asWeightedVar :: WeightedSum f sr -> Maybe (SR.Coefficient sr, f (SR.SemiRingBase sr))
asWeightedVar ws =
  case AM.toList (_sumMap ws) of
    [(WrapF term, coeff)]
      | SR.eq repr (SR.zero repr) (_sumOffset ws) -> Just (coeff, term)
    _ -> Nothing
  where
    repr = sumRepr ws

-- | Attempt to view a weighted sum as a single bare term.
--   @asVar w = Just r@ when @denotation(w) = r@, i.e. the sum has exactly
--   one term, its coefficient is one, and the offset is zero.
asVar :: WeightedSum f sr -> Maybe (f (SR.SemiRingBase sr))
asVar ws =
  -- Reuse the single-scaled-term view, then additionally demand a unit
  -- coefficient.
  case asWeightedVar ws of
    Just (coeff, term) | SR.eq repr (SR.one repr) coeff -> Just term
    _ -> Nothing
  where
    repr = sumRepr ws

-- | Build a weighted sum with no variable terms whose offset is the given
--   coefficient.
constant :: Tm f => SR.SemiRingRepr sr -> SR.Coefficient sr -> WeightedSum f sr
constant sr = unfilteredSum sr AM.empty

-- | Traverse the variable terms of a weighted sum, keeping each term's
--   coefficient and the overall offset unchanged.
--   The result is rebuilt with 'fromTerms'. NOTE(review): how colliding
--   results (two terms mapped to the same new term) are combined is decided
--   by 'fromListSumMap', which is not visible here.
traverseVars :: forall k j m sr.
  (Applicative m, Tm k) =>
  (j (SR.SemiRingBase sr) -> m (k (SR.SemiRingBase sr))) ->
  WeightedSum j sr ->
  m (WeightedSum k sr)
traverseVars visit ws =
  rebuild <$> traverse visitTerm (toListSumMap (_sumMap ws))
  where
    repr = sumRepr ws
    -- Apply the visitor to the term half of each (term, coefficient) pair.
    visitTerm (term, coeff) = (\term' -> (term', coeff)) <$> visit term
    rebuild terms = fromTerms repr terms (_sumOffset ws)

-- | Traverse the coefficients of a weighted sum.
--   Each term coefficient is visited first (in map order), then the offset.
--   Terms whose new coefficient equals the semiring zero are dropped, and
--   the notes of surviving terms are recomputed from the new coefficient.
--   The offset is kept even when it becomes zero.
traverseCoeffs :: forall m f sr.
  (Applicative m, Tm f) =>
  (SR.Coefficient sr -> m (SR.Coefficient sr)) ->
  WeightedSum f sr ->
  m (WeightedSum f sr)
traverseCoeffs visit ws =
  unfilteredSum repr
    <$> AM.traverseMaybeWithKey visitEntry (_sumMap ws)
    <*> visit (_sumOffset ws)
  where
    repr = sumRepr ws
    visitEntry (WrapF term) _oldNote coeff = keepNonZero term <$> visit coeff
    keepNonZero term coeff
      | SR.eq repr (SR.zero repr) coeff = Nothing
      | otherwise                       = Just (mkNote repr coeff term, coeff)

-- | Traverse the variables of a semiring product, keeping each variable's
--   occurrence count.
--   The map is rebuilt by inserting every visited entry, in order, into an
--   empty map with 'AM.insert', recomputing each product note.
--   NOTE(review): if the visitor sends two distinct variables to the same
--   result, presumably only the later entry's occurrence survives (depends
--   on 'AM.insert' semantics, not visible here) rather than the occurrences
--   being combined. Confirm whether callers rely on the visitor being
--   injective.
traverseProdVars :: forall k j m sr.
  (Applicative m, Tm k) =>
  (j (SR.SemiRingBase sr) -> m (k (SR.SemiRingBase sr))) ->
  SemiRingProduct j sr ->
  m (SemiRingProduct k sr)
traverseProdVars visit pd =
  fmap (mkProd repr . rebuild) (traverse visitEntry (AM.toList (_prodMap pd)))
 where
  repr = prodRepr pd
  -- Visit only the (wrapped) key of each entry; the occurrence is unchanged.
  visitEntry (key, occ) = (\key' -> (key', occ)) <$> traverseWrap visit key
  rebuild = foldl' step AM.empty
  step acc (WrapF term, occ) =
    AM.insert (WrapF term) (mkProdNote repr occ term) occ acc


-- | Build the weighted sum @s * t@ with a zero offset.
--   A zero coefficient produces the empty sum instead of a zero-weighted
--   term.
scaledVar :: Tm f => SR.SemiRingRepr sr -> SR.Coefficient sr -> f (SR.SemiRingBase sr) -> WeightedSum f sr
scaledVar sr coeff term = unfilteredSum sr terms (SR.zero sr)
  where
    terms
      | SR.eq sr (SR.zero sr) coeff = AM.empty
      | otherwise                   = singletonSumMap sr coeff term

-- | Build the weighted sum consisting of the given variable with
--   coefficient one and a zero offset.
var :: Tm f => SR.SemiRingRepr sr -> f (SR.SemiRingBase sr) -> WeightedSum f sr
var sr term = unfilteredSum sr singleton offset0
  where
    singleton = singletonSumMap sr (SR.one sr) term
    offset0   = SR.zero sr

-- | Add two sums. Coefficients of terms present in both operands are
--   summed, and any term whose combined coefficient is zero is removed.
--   The offsets are added.
add ::
  Tm f =>
  SR.SemiRingRepr sr ->
  WeightedSum f sr ->
  WeightedSum f sr ->
  WeightedSum f sr
add sr x y =
  unfilteredSum sr
    (AM.unionWithKeyMaybe combine (_sumMap x) (_sumMap y))
    (SR.add sr (_sumOffset x) (_sumOffset y))
  where
    -- Called for keys present in both maps.
    combine (WrapF term) cx cy =
      let total = SR.add sr cx cy
      in if SR.eq sr total (SR.zero sr)
           then Nothing
           else Just (mkNote sr total term, total)

-- | Build the weighted sum @x + y@, each with coefficient one and a zero
--   offset. Combining @x@ and @y@ when they are the same term is left to
--   'fromTerms'.
addVars ::
  Tm f =>
  SR.SemiRingRepr sr ->
  f (SR.SemiRingBase sr) ->
  f (SR.SemiRingBase sr) ->
  WeightedSum f sr
addVars sr x y = fromTerms sr [ (t, unit) | t <- [x, y] ] (SR.zero sr)
  where
    unit = SR.one sr

-- | Add a variable, with coefficient one, to a sum. The term is inserted
--   with 'insertSumMap'; the offset is unchanged.
addVar ::
  Tm f =>
  SR.SemiRingRepr sr ->
  WeightedSum f sr -> f (SR.SemiRingBase sr) -> WeightedSum f sr
addVar sr wsum x = wsum & sumMap %~ insertSumMap sr (SR.one sr) x

-- | Add a constant to a sum's offset, computed as @SR.add sr r offset@.
--   The term map is unchanged.
addConstant :: SR.SemiRingRepr sr -> WeightedSum f sr -> SR.Coefficient sr -> WeightedSum f sr
addConstant sr ws r = ws { _sumOffset = SR.add sr r (_sumOffset ws) }

-- | Multiply a sum by a constant coefficient.
--   Scaling by zero yields the constant 0. Otherwise every coefficient and
--   the offset are multiplied by @c@ on the left. Terms whose scaled
--   coefficient becomes zero (possible in semirings with zero divisors) are
--   dropped.
scale :: Tm f => SR.SemiRingRepr sr -> SR.Coefficient sr -> WeightedSum f sr -> WeightedSum f sr
scale sr c ws
  | SR.eq sr c zeroC = constant sr zeroC
  | otherwise =
      unfilteredSum sr
        (AM.mapMaybeWithKey scaleEntry (_sumMap ws))
        (SR.mul sr c (_sumOffset ws))
  where
    zeroC = SR.zero sr
    scaleEntry (WrapF term) _oldNote coeff =
      case SR.mul sr c coeff of
        c' | SR.eq sr zeroC c' -> Nothing
           | otherwise         -> Just (mkNote sr c' term, c')

-- | Build a weighted sum from a list of (term, coefficient) pairs and an
--   offset. The map is built by 'fromListSumMap'.
fromTerms ::
  Tm f =>
  SR.SemiRingRepr sr ->
  [(f (SR.SemiRingBase sr), SR.Coefficient sr)] ->
  SR.Coefficient sr ->
  WeightedSum f sr
fromTerms sr tms = unfilteredSum sr (fromListSumMap sr tms)

-- | Apply update functions to the terms and coefficients of a weighted sum,
--   possibly moving it into a different semiring @sr'@.
--   For each term, in list order, the term function runs before the
--   coefficient function. All term effects precede the offset's
--   coefficient effect. The result is rebuilt with 'fromTerms'.
transformSum :: (Applicative m, Tm g) =>
  SR.SemiRingRepr sr' ->
  (SR.Coefficient sr -> m (SR.Coefficient sr')) ->
  (f (SR.SemiRingBase sr) -> m (g (SR.SemiRingBase sr'))) ->
  WeightedSum f sr ->
  m (WeightedSum g sr')
transformSum sr' transCoef transTm ws =
  fromTerms sr'
    <$> traverse transformTerm (toListSumMap (_sumMap ws))
    <*> transCoef (_sumOffset ws)
  where
    transformTerm (term, coeff) = (,) <$> transTm term <*> transCoef coeff


-- | Evaluate a sum given interpretations of addition, scalar
-- multiplication, and a constant. This evaluation is threaded through
-- a monad.
--
-- A zero offset is omitted. The first term seeds the accumulator, and an
-- empty, zero-offset sum evaluates to @cnst zero@. A non-zero offset seeds
-- the accumulator instead. Each remaining term is evaluated with the
-- scalar-multiply callback and combined as @addFn acc term@, so addition
-- associates to the left, as in 'foldlM'.
evalM :: Monad m =>
  (r -> r -> m r) {- ^ Addition function -} ->
  (SR.Coefficient sr -> f (SR.SemiRingBase sr) -> m r) {- ^ Scalar multiply -} ->
  (SR.Coefficient sr -> m r) {- ^ Constant evaluation -} ->
  WeightedSum f sr ->
  m r
evalM addFn smul cnst sm =
  case (offsetIsZero, toListSumMap (_sumMap sm)) of
    (True, [])            -> cnst (SR.zero repr)
    (True, (e, s) : rest) -> smul s e >>= accumulate rest
    (False, terms)        -> cnst (_sumOffset sm) >>= accumulate terms
  where
    repr = sumRepr sm
    offsetIsZero = SR.eq repr (_sumOffset sm) (SR.zero repr)

    -- Thread the running total through the remaining terms; the running
    -- total is the first argument to addFn.
    accumulate [] acc = return acc
    accumulate ((e, s) : rest) acc = do
      term <- smul s e
      acc' <- addFn acc term
      accumulate rest acc'

-- | Evaluate a sum given interpretations of addition, scalar multiplication,
-- and a constant semiring coefficient.
--
-- A zero offset is omitted. The first term seeds the accumulator, and an
-- empty, zero-offset sum evaluates to @cnst zero@. A non-zero offset seeds
-- the accumulator instead. Unlike 'evalM', each newly evaluated term is the
-- /first/ argument to the addition function and the running total the
-- second, i.e. @addFn term acc@.
eval ::
  (r -> r -> r) {- ^ Addition function -} ->
  (SR.Coefficient sr -> f (SR.SemiRingBase sr) -> r) {- ^ Scalar multiply -} ->
  (SR.Coefficient sr -> r) {- ^ Constant evaluation -} ->
  WeightedSum f sr ->
  r
eval addFn smul cnst w =
  case (offsetIsZero, toListSumMap (_sumMap w)) of
    (True, [])            -> cnst (SR.zero repr)
    (True, (e, s) : rest) -> loop rest (smul s e)
    (False, terms)        -> loop terms (cnst (_sumOffset w))
  where
    repr = sumRepr w
    offsetIsZero = SR.eq repr (_sumOffset w) (SR.zero repr)

    -- Left-to-right accumulation; the accumulator is not forced between
    -- steps.
    loop [] acc = acc
    loop ((e, s) : rest) acc = loop rest (addFn (smul s e) acc)

{-# INLINABLE eval #-}


-- | Reduce a weighted sum of integers modulo a concrete integer.
--   This reduces each of the coefficients modulo the given integer,
--   removing any that are congruent to 0; the offset value is
--   also reduced.
--
--   Reduction uses Haskell's 'mod', so each result has the sign of the
--   modulus. A modulus of 0 raises a divide-by-zero exception when any
--   coefficient or the offset is forced.
reduceIntSumMod ::
  Tm f =>
  WeightedSum f SR.SemiRingInteger {- ^ The sum to reduce -} ->
  Integer {- ^ The modulus, must not be 0 -} ->
  WeightedSum f SR.SemiRingInteger
reduceIntSumMod ws k =
  unfilteredSum SR.SemiRingIntegerRepr reduced (_sumOffset ws `mod` k)
  where
    repr = sumRepr ws
    -- Pure filter-map over the term map (same combinator 'scale' uses);
    -- no Identity-monad traversal is needed.
    reduced = AM.mapMaybeWithKey reduceEntry (_sumMap ws)
    reduceEntry (WrapF term) _oldNote coeff =
      let coeff' = coeff `mod` k
      in if coeff' == 0
           then Nothing
           else Just (mkNote repr coeff' term, coeff')

{-# INLINABLE extractCommon #-}

-- | Given two weighted sums @x@ and @y@, this returns a triple @(z,x',y')@
-- where @x = z + x'@ and @y = z + y'@ and @z@ contains the "common"
-- parts of @x@ and @y@. A term is common only when it occurs with the
-- same coefficient in both sums. The offset is common only when both
-- offsets are equal; otherwise @z@'s offset is zero and each remainder
-- keeps its own offset.
--
-- This is primarily used to simplify if-then-else expressions to
-- preserve shared subterms.
extractCommon ::
  Tm f =>
  WeightedSum f sr ->
  WeightedSum f sr ->
  (WeightedSum f sr, WeightedSum f sr, WeightedSum f sr)
extractCommon (WeightedSum xm xc sr) (WeightedSum ym yc _) =
  ( unfilteredSum sr common commonOffset
  , unfilteredSum sr (xm `AM.difference` common) xOffset
  , unfilteredSum sr (ym `AM.difference` common) yOffset
  )
  where
    zeroC = SR.zero sr
    sameOffset = SR.eq sr xc yc

    commonOffset = if sameOffset then xc else zeroC
    xOffset      = if sameOffset then zeroC else xc
    yOffset      = if sameOffset then zeroC else yc

    -- Keys found in only one map are discarded; keys found in both are kept
    -- only when their coefficients agree.
    common = AM.mergeWithKey keepIfEqual (const AM.empty) (const AM.empty) xm ym
    keepIfEqual (WrapF term) (_, cx) (_, cy)
      | SR.eq sr cx cy = Just (mkNote sr cx term, cx)
      | otherwise      = Nothing


-- | Test whether a product is trivial, i.e. its term map holds no terms.
nullProd :: SemiRingProduct f sr -> Bool
nullProd = AM.null . _prodMap

-- | If the product consists of exactly one term, occurring exactly once,
--   return that term; otherwise return 'Nothing'.
asProdVar :: SemiRingProduct f sr -> Maybe (f (SR.SemiRingBase sr))
asProdVar pd =
  case AM.toList (_prodMap pd) of
    [(WrapF x, occ)] | SR.occ_count (prodRepr pd) occ == 1 -> Just x
    _ -> Nothing

-- | Abstract value of a product, read from the annotation cached on its
--   term map.  If the map yields no annotation, the abstract value of the
--   multiplicative unit 'SR.one' is used instead (presumably this happens
--   only for an empty product).
prodAbsValue :: OrdF f => SemiRingProduct f sr -> AD.AbstractValue (SR.SemiRingBase sr)
prodAbsValue pd =
  fromSRAbsValue $
    case AM.annotation (_prodMap pd) of
      Just (ProdNote _ v) -> v
      Nothing             -> abstractScalar sr (SR.one sr)
  where
    sr = prodRepr pd

-- | Returns true if the product contains at least one occurrence of the
--   given term (i.e. the term is a key of the product's term map).
prodContains :: OrdF f => SemiRingProduct f sr -> f (SR.SemiRingBase sr) -> Bool
prodContains pd x =
  case AM.lookup (WrapF x) (_prodMap pd) of
    Just _  -> True
    Nothing -> False

-- | Wrap a raw map of terms to occurrences into a product over the given
--   semiring.  No normalization or checking is performed.
--   PRECONDITION: the occurrence value for each term should be non-zero.
mkProd :: SR.SemiRingRepr sr -> ProdMap f sr -> SemiRingProduct f sr
mkProd = flip SemiRingProduct

-- | Produce a product consisting of the single given term with a single
--   occurrence ('SR.occ_one').
prodVar :: Tm f => SR.SemiRingRepr sr -> f (SR.SemiRingBase sr) -> SemiRingProduct f sr
prodVar sr = mkProd sr . singletonProdMap sr (SR.occ_one sr)

-- | Multiply two products, collecting terms and adding occurrences.
--
--   Terms present in only one operand are kept unchanged; terms present in
--   both get their occurrences combined with 'SR.occ_add' and a freshly
--   computed note.  The semiring representation of the result is taken
--   from the first operand.
prodMul :: Tm f => SemiRingProduct f sr -> SemiRingProduct f sr -> SemiRingProduct f sr
prodMul x y = mkProd sr (AM.mergeWithKey combine id id (_prodMap x) (_prodMap y))
  where
    sr = prodRepr x

    combine (WrapF k) (_, occX) (_, occY) =
      let total = SR.occ_add sr occX occY
      in Just (mkProdNote sr total k, total)

-- | Evaluate a product, given a pure function representing multiplication
--   and a pure function to evaluate terms.  This is 'prodEvalM' run in the
--   'Identity' monad, so it returns 'Nothing' exactly when 'prodEvalM'
--   would.
prodEval ::
  (r -> r -> r) {-^ multiplication evaluation -} ->
  (f (SR.SemiRingBase sr) -> r) {-^ term evaluation -} ->
  SemiRingProduct f sr ->
  Maybe r
prodEval mul tm om = runIdentity (prodEvalM mulI tmI om)
  where
    mulI a b = Identity (mul a b)
    tmI term = Identity (tm term)

-- | Evaluate a product, given a function representing multiplication
--   and a function to evaluate terms, where both functions are threaded
--   through a monad.
--
--   Terms are visited in the order produced by 'AM.toList'.  Terms whose
--   occurrence count is zero are skipped without calling the term
--   evaluator; every other term is evaluated exactly once.  Powers are
--   built by repeated multiplication: the first term with count @n@
--   costs @n-1@ calls to the multiplier, and each later term with count
--   @n@ costs @n@ calls.  The result is 'Nothing' when no term has a
--   non-zero occurrence count.
prodEvalM :: Monad m =>
  (r -> r -> m r) {-^ multiplication evaluation -} ->
  (f (SR.SemiRingBase sr) -> m r) {-^ term evaluation -} ->
  SemiRingProduct f sr ->
  m (Maybe r)
prodEvalM mul tm om = seed live
  where
    sr = prodRepr om

    -- (term, occurrence count) pairs, with zero-count terms removed.
    live =
      [ (x, n)
      | (WrapF x, occ) <- AM.toList (_prodMap om)
      , let n = SR.occ_count sr occ
      , n /= 0
      ]

    -- No partial product yet: the first term @t@ with count @n@ starts
    -- the accumulator at @t@ and multiplies in the remaining @n-1@ copies.
    seed [] = return Nothing
    seed ((x, n) : rest) =
      do t   <- tm x
         acc <- power (n - 1) t t
         extend acc rest

    -- A partial product @acc@ already exists; multiply in every
    -- remaining term raised to its occurrence count.
    extend acc [] = return (Just acc)
    extend acc ((x, n) : rest) =
      do t    <- tm x
         acc' <- power n t acc
         extend acc' rest

    -- power k t z computes z * t^k using k left-nested calls to mul.
    power k t z
      | k > 0     = mul z t >>= power (k - 1) t
      | otherwise = return z