{-|
Module      : What4.Expr.WeightedSum
Description : Representations for weighted sums and products in semirings
Copyright   : (c) Galois Inc, 2015-2020
License     : BSD3
Maintainer  : jhendrix@galois.com

Declares a weighted sum type used for representing sums over variables and an offset
in one of the supported semirings.  This module also implements a representation of
semiring products.
-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wwarn #-}
module What4.Expr.WeightedSum
  ( -- * Utilities
    Tm
    -- * Weighted sums
  , WeightedSum
  , sumRepr
  , sumOffset
  , sumAbsValue
  , constant
  , var
  , scaledVar
  , asConstant
  , asVar
  , asWeightedVar
  , asAffineVar
  , isZero
  , traverseVars
  , traverseCoeffs
  , add
  , addVar
  , addVars
  , addConstant
  , scale
  , eval
  , evalM
  , extractCommon
  , fromTerms
  , transformSum
  , reduceIntSumMod

    -- * Ring products
  , SemiRingProduct
  , traverseProdVars
  , nullProd
  , asProdVar
  , prodRepr
  , prodVar
  , prodAbsValue
  , prodMul
  , prodEval
  , prodEvalM
  , prodContains
  ) where

import           Control.Lens
import           Control.Monad.State
import qualified Data.BitVector.Sized as BV
import           Data.Hashable
import           Data.Kind
import           Data.List (foldl')
import           Data.Maybe
import           Data.Parameterized.Classes

import           What4.BaseTypes
import qualified What4.SemiRing as SR
import           What4.Utils.AnnotatedMap (AnnotatedMap)
import qualified What4.Utils.AnnotatedMap as AM
import qualified What4.Utils.AbstractDomains as AD
import qualified What4.Utils.BVDomain.Arith as A
import qualified What4.Utils.BVDomain.XOR as X
import qualified What4.Utils.BVDomain as BVD

import           What4.Utils.IncrHash

--------------------------------------------------------------------------------

-- | Abstract-domain approximation of a value in the semiring @sr@.
-- There is one constructor per supported semiring.  The arithmetic and
-- bitwise flavors of bitvector semirings use different abstract domains.
data SRAbsValue :: SR.SemiRing -> Type where
  -- | Integer semiring: an 'AD.ValueRange' over 'Integer'.
  SRAbsIntAdd  :: !(AD.ValueRange Integer)  -> SRAbsValue SR.SemiRingInteger
  -- | Real semiring: an 'AD.RealAbstractValue'.
  SRAbsRealAdd :: !AD.RealAbstractValue     -> SRAbsValue SR.SemiRingReal
  -- | Arithmetic bitvector semiring: the arithmetic domain 'A.Domain'.
  SRAbsBVAdd   :: (1 <= w) => !(A.Domain w) -> SRAbsValue (SR.SemiRingBV SR.BVArith w)
  -- | Bitwise (xor / and) bitvector semiring: the xor domain 'X.Domain'.
  SRAbsBVXor   :: (1 <= w) => !(X.Domain w) -> SRAbsValue (SR.SemiRingBV SR.BVBits w)

-- | Abstract semiring addition.
--
-- * Integers use 'AD.addRange' and reals use 'AD.ravAdd'.
-- * Arithmetic bitvectors use 'A.add'.
-- * Bitwise bitvectors use 'X.xor', since semiring addition in that flavor
--   is exclusive-or.
--
-- Both arguments share the index @sr@, so mismatched constructors cannot occur.
instance Semigroup (SRAbsValue sr) where
  SRAbsIntAdd  x <> SRAbsIntAdd  y = SRAbsIntAdd  (AD.addRange x y)
  SRAbsRealAdd x <> SRAbsRealAdd y = SRAbsRealAdd (AD.ravAdd x y)
  SRAbsBVAdd   x <> SRAbsBVAdd   y = SRAbsBVAdd   (A.add x y)
  SRAbsBVXor   x <> SRAbsBVXor   y = SRAbsBVXor   (X.xor x y)


-- | Abstract semiring multiplication.
--
-- * Integers use 'AD.mulRange' and reals use 'AD.ravMul'.
-- * Arithmetic bitvectors use 'A.mul'.
-- * Bitwise bitvectors use 'X.and', since semiring multiplication in that
--   flavor is bitwise and.
(.**) :: SRAbsValue sr -> SRAbsValue sr -> SRAbsValue sr
SRAbsIntAdd  x .** SRAbsIntAdd  y = SRAbsIntAdd  (AD.mulRange x y)
SRAbsRealAdd x .** SRAbsRealAdd y = SRAbsRealAdd (AD.ravMul x y)
SRAbsBVAdd   x .** SRAbsBVAdd   y = SRAbsBVAdd   (A.mul x y)
SRAbsBVXor   x .** SRAbsBVXor   y = SRAbsBVXor   (X.and x y)

-- | Abstract value of a single weighted-sum term: the abstract value of the
-- expression @e@ scaled by the coefficient @c@, using each domain's scalar
-- multiplication.
abstractTerm ::
  AD.HasAbsValue f =>
  SR.SemiRingRepr sr -> SR.Coefficient sr -> f (SR.SemiRingBase sr) -> SRAbsValue sr
abstractTerm sr c e =
  case sr of
    SR.SemiRingIntegerRepr ->
      SRAbsIntAdd (AD.rangeScalarMul c (AD.getAbsValue e))
    SR.SemiRingRealRepr ->
      SRAbsRealAdd (AD.ravScalarMul c (AD.getAbsValue e))
    SR.SemiRingBVRepr fv w ->
      case fv of
        SR.BVArithRepr ->
          -- A.scale expects a signed integer coefficient
          SRAbsBVAdd (A.scale (BV.asSigned w c) (BVD.asArithDomain (AD.getAbsValue e)))
        SR.BVBitsRepr ->
          -- in the bitwise semiring, scaling by a coefficient is and-ing with it
          SRAbsBVXor (X.and_scalar (BV.asUnsigned c) (BVD.asXorDomain (AD.getAbsValue e)))

-- | Convert the abstract value of an expression into the 'SRAbsValue'
-- representation for the semiring @sr@.  For bitvectors, the 'BVD.BVDomain'
-- is projected onto the arithmetic or xor domain, matching the flavor.
abstractVal :: AD.HasAbsValue f => SR.SemiRingRepr sr -> f (SR.SemiRingBase sr) -> SRAbsValue sr
abstractVal sr e =
  case sr of
    SR.SemiRingIntegerRepr -> SRAbsIntAdd (AD.getAbsValue e)
    SR.SemiRingRealRepr    -> SRAbsRealAdd (AD.getAbsValue e)
    SR.SemiRingBVRepr fv _w ->
      case fv of
        SR.BVArithRepr -> SRAbsBVAdd (BVD.asArithDomain (AD.getAbsValue e))
        SR.BVBitsRepr  -> SRAbsBVXor (BVD.asXorDomain (AD.getAbsValue e))

-- | The abstract value containing exactly the constant coefficient @c@.
-- Bitvector constants are given to the domains in their unsigned form.
abstractScalar ::
  SR.SemiRingRepr sr -> SR.Coefficient sr -> SRAbsValue sr
abstractScalar sr c =
  case sr of
    SR.SemiRingIntegerRepr -> SRAbsIntAdd (AD.SingleRange c)
    SR.SemiRingRealRepr    -> SRAbsRealAdd (AD.ravSingle c)
    SR.SemiRingBVRepr fv w ->
      case fv of
        SR.BVArithRepr -> SRAbsBVAdd (A.singleton w (BV.asUnsigned c))
        SR.BVBitsRepr  -> SRAbsBVXor (X.singleton w (BV.asUnsigned c))

-- | Convert an 'SRAbsValue' back into the generic 'AD.AbstractValue' of the
-- semiring's base type.  Integer and real values are already in that form.
-- Bitvector domains are wrapped back into a 'BVD.BVDomain'.
fromSRAbsValue ::
  SRAbsValue sr -> AD.AbstractValue (SR.SemiRingBase sr)
fromSRAbsValue v =
  case v of
    SRAbsIntAdd  x -> x
    SRAbsRealAdd x -> x
    SRAbsBVAdd   x -> BVD.BVDArith x
    SRAbsBVXor   x -> BVD.fromXorDomain x

--------------------------------------------------------------------------------

-- | Constraints required of the term type @f@ stored in weighted sums and
-- products: hashing ('HashableF'), ordering ('OrdF'), and access to an
-- abstract value ('AD.HasAbsValue').
type Tm f = (HashableF f, OrdF f, AD.HasAbsValue f)

-- | Wraps a term indexed by the base type of semiring @i@, so that it is
-- indexed by the semiring itself.  Used as the key type of the sum and
-- product maps.
newtype WrapF (f :: BaseType -> Type) (i :: SR.SemiRing) = WrapF (f (SR.SemiRingBase i))

-- | Orders wrapped terms using the 'OrdF' instance of the underlying term type.
instance OrdF f => Ord (WrapF f i) where
  compare (WrapF x) (WrapF y) = toOrdering $ compareF x y

-- | Two wrapped terms are equal exactly when 'testEquality' succeeds on them.
instance TestEquality f => Eq (WrapF f i) where
  (WrapF x) == (WrapF y) = isJust $ testEquality x y

-- | Hashes a wrapped term with the 'HashableF' instance of the underlying term.
instance HashableF f => Hashable (WrapF f i) where
  hashWithSalt s (WrapF x) = hashWithSaltF s x

-- | Apply an effectful transformation to the term inside a 'WrapF',
-- re-wrapping the result.
traverseWrap :: Functor m => (f (SR.SemiRingBase i) -> m (g (SR.SemiRingBase i))) -> WrapF f i -> m (WrapF g i)
traverseWrap f (WrapF x) = WrapF <$> f x

-- | The annotation type used for the annotated map of a weighted sum.
-- It consists of an incremental hash and the abstract-domain approximation
-- ('SRAbsValue') of each submap.  Annotations of adjacent submaps are
-- combined by combining the hashes and adding the abstract values.
data Note sr = Note !IncrHash !(SRAbsValue sr)

-- | Combine the annotations of two adjacent submaps.  Hashes are combined
-- with their own '<>', and abstract values with abstract addition.
instance Semigroup (Note sr) where
  Note h1 d1 <> Note h2 d2 = Note (h1 <> h2) (d1 <> d2)

-- | The annotation type used for the annotated map for products.
-- It consists of an incremental hash and the abstract-domain approximation
-- ('SRAbsValue') of each submap.  NOTE! The multiplication operation on
-- abstract values is not always associative.  This is acceptable because
-- every grouping yields a sound (but perhaps not best) approximate value.
data ProdNote sr = ProdNote !IncrHash !(SRAbsValue sr)

-- | Combine the annotations of two adjacent product submaps.  Hashes are
-- combined with '<>', and abstract values with abstract multiplication
-- ('.**').  As noted on 'ProdNote', the abstract part need not be
-- associative.
instance Semigroup (ProdNote sr) where
  ProdNote h1 d1 <> ProdNote h2 d2 = ProdNote (h1 <> h2) (d1 .** d2)

-- | Construct the annotation for a single sum-map entry (term @t@ with
-- coefficient @c@), consisting of:
--
-- * the hash of the coefficient salted with the term's hash, and
-- * the abstract value of the term scaled by the coefficient.
mkNote ::
  (HashableF f, AD.HasAbsValue f) =>
  SR.SemiRingRepr sr -> SR.Coefficient sr -> f (SR.SemiRingBase sr) -> Note sr
mkNote sr c t = Note (mkIncrHash h) d
  where
    h = SR.sr_hashWithSalt sr (hashF t) c
    d = abstractTerm sr c t

-- | Construct the annotation for a single product-map entry (term @t@
-- occurring @occ@ times), consisting of:
--
-- * the hash of the occurrence salted with the term's hash, and
-- * the term's abstract value multiplied by itself 'SR.occ_count' times,
--   using repeated abstract multiplication.
mkProdNote ::
  (HashableF f, AD.HasAbsValue f) =>
  SR.SemiRingRepr sr ->
  SR.Occurrence sr ->
  f (SR.SemiRingBase sr) ->
  ProdNote sr
mkProdNote sr occ t = ProdNote (mkIncrHash h) d
  where
    h = SR.occ_hashWithSalt sr (hashF t) occ
    v = abstractVal sr t
    power = fromIntegral (SR.occ_count sr occ)

    -- d = v .** (v .** (... v)) with 'power' factors.
    -- NOTE(review): a count of 0 also yields v rather than a multiplicative
    -- identity; presumably stored entries always have a count >= 1 -- confirm.
    d = go (power - 1) v

    -- No standalone signature here: 'sr' is not in scope for where-bindings.
    go (n :: Integer) x
      | n > 0     = go (n - 1) (v .** x)
      | otherwise = x

-- | Annotated map from terms to their coefficients in a weighted sum.
type SumMap f sr  = AnnotatedMap (WrapF f sr) (Note sr) (SR.Coefficient sr)
-- | Annotated map from terms to their occurrences in a product.
type ProdMap f sr = AnnotatedMap (WrapF f sr) (ProdNote sr) (SR.Occurrence sr)

-- | Add the coefficient @c@ to the entry for term @t@, recomputing the
-- entry's annotation.
--
-- * If the term is absent, it is inserted with coefficient @c@.
-- * If the summed coefficient becomes zero, the entry is removed.
--
-- NOTE(review): when the term is absent, a zero @c@ is still stored as an
-- entry; confirm that callers never pass a zero coefficient here.
insertSumMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  SR.Coefficient sr -> f (SR.SemiRingBase sr) -> SumMap f sr -> SumMap f sr
insertSumMap sr c t = AM.alter f (WrapF t)
  where
    f Nothing = Just (mkNote sr c t, c)
    f (Just (_, c0))
      | SR.eq sr (SR.zero sr) c' = Nothing
      | otherwise = Just (mkNote sr c' t, c')
      where c' = SR.add sr c0 c

-- | A sum map containing only the term @t@ with coefficient @c@.
singletonSumMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  SR.Coefficient sr -> f (SR.SemiRingBase sr) -> SumMap f sr
singletonSumMap sr c t = AM.singleton (WrapF t) (mkNote sr c t) c

-- | A product map containing only the term @t@ with occurrence @occ@.
singletonProdMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  SR.Occurrence sr ->
  f (SR.SemiRingBase sr) ->
  ProdMap f sr
singletonProdMap sr occ t = AM.singleton (WrapF t) (mkProdNote sr occ t) occ

-- | Build a sum map from a list of (term, coefficient) pairs.  Each pair is
-- added with 'insertSumMap', so repeated terms have their coefficients
-- summed.  Pairs are inserted right-to-left, starting from the last element.
fromListSumMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  [(f (SR.SemiRingBase sr), SR.Coefficient sr)] -> SumMap f sr
fromListSumMap sr = foldr (\(t, c) acc -> insertSumMap sr c t acc) AM.empty

-- | The (term, coefficient) pairs of a sum map, in the order produced by
-- 'AM.toList', with the 'WrapF' wrappers removed.
toListSumMap :: SumMap f sr -> [(f (SR.SemiRingBase sr), SR.Coefficient sr)]
toListSumMap am = [ (t, c) | (WrapF t, c) <- AM.toList am ]

-- | A weighted sum of semiring values.  Mathematically, this represents
--   an affine operation on the underlying expressions: a sum of terms
--   scaled by coefficients, plus a constant offset.
data WeightedSum (f :: BaseType -> Type) (sr :: SR.SemiRing)
   = WeightedSum { _sumMap     :: !(SumMap f sr)
                     -- ^ The terms of the sum and their coefficients.
                 , _sumOffset  :: !(SR.Coefficient sr)
                     -- ^ The constant offset of the sum.
                 , sumRepr     :: !(SR.SemiRingRepr sr)
                     -- ^ Runtime representation of the semiring for this sum.
                 }

-- | A product of semiring values, stored as a map from each term to its
--   occurrence.
data SemiRingProduct (f :: BaseType -> Type) (sr :: SR.SemiRing)
   = SemiRingProduct { _prodMap  :: !(ProdMap f sr)
                         -- ^ The factors of the product and their occurrences.
                     , prodRepr  :: !(SR.SemiRingRepr sr)
                         -- ^ Runtime representation of the semiring for this product
                     }

-- | Return the hash of the 'SumMap' part of the 'WeightedSum'.  The offset
-- is not included.  An empty map hashes to 'mempty'.
sumMapHash :: OrdF f => WeightedSum f sr -> IncrHash
sumMapHash x =
  case AM.annotation (_sumMap x) of
    Nothing         -> mempty
    Just (Note h _) -> h

-- | Return the hash of the 'ProdMap' of a 'SemiRingProduct'.  An empty map
-- hashes to 'mempty'.
prodMapHash :: OrdF f => SemiRingProduct f sr -> IncrHash
prodMapHash pd =
  case AM.annotation (_prodMap pd) of
    Nothing             -> mempty
    Just (ProdNote h _) -> h

-- | Abstract value of an entire weighted sum: the abstract offset added to
-- the combined abstract value of all terms.  When there are no terms, the
-- result is just the offset.
sumAbsValue :: OrdF f => WeightedSum f sr -> AD.AbstractValue (SR.SemiRingBase sr)
sumAbsValue wsum =
  fromSRAbsValue $
  case AM.annotation (_sumMap wsum) of
    Nothing         -> absOffset
    Just (Note _ v) -> absOffset <> v
  where
    absOffset = abstractScalar (sumRepr wsum) (_sumOffset wsum)

-- | Two products are equal when all of the following hold:
--
-- * their map hashes agree (checked first, as an early rejection);
-- * their semiring representations are equal;
-- * their maps contain the same terms with equal occurrences, compared
--   with 'SR.occ_eq'.
instance OrdF f => TestEquality (SemiRingProduct f) where
  testEquality x y
    | prodMapHash x /= prodMapHash y = Nothing
    | otherwise =
        do Refl <- testEquality (prodRepr x) (prodRepr y)
           unless (AM.eqBy (SR.occ_eq (prodRepr x)) (_prodMap x) (_prodMap y)) Nothing
           return Refl

-- | Two weighted sums are equal when they live in the same semiring and have
-- equal offsets and equal term maps.  The cached map hashes are compared
-- first as a cheap way to reject most unequal pairs.
instance OrdF f => TestEquality (WeightedSum f) where
  testEquality x y
    | sumMapHash x /= sumMapHash y = Nothing
    | otherwise =
        case testEquality (sumRepr x) (sumRepr y) of
          Nothing -> Nothing
          Just Refl ->
            let sr         = sumRepr x
                sameOffset = SR.eq sr (_sumOffset x) (_sumOffset y)
                -- only inspected when the offsets already agree ('&&' is lazy)
                sameTerms  = AM.eqBy (SR.eq sr) (_sumMap x) (_sumMap y)
            in if sameOffset && sameTerms then Just Refl else Nothing


-- | Create a weighted sum directly from a term map and a constant offset.
--
-- No normalization is performed here: the caller must make sure that the
-- map contains no entries whose coefficient is zero.
unfilteredSum ::
  SR.SemiRingRepr sr ->
  SumMap f sr ->
  SR.Coefficient sr ->
  WeightedSum f sr
unfilteredSum sr terms offset = WeightedSum terms offset sr

-- | Lens onto the mapping from terms to their coefficients.
--
-- The setter stores the given map unchanged; it does not filter out
-- zero coefficients.
sumMap :: HashableF f => Lens' (WeightedSum f sr) (SumMap f sr)
sumMap = lens _sumMap replaceMap
  where
    replaceMap w newMap = w { _sumMap = newMap }

-- | Lens onto the constant addend (offset) of the weighted sum.
sumOffset :: Lens' (WeightedSum f sr) (SR.Coefficient sr)
sumOffset onOffset w =
  fmap (\newOffset -> w { _sumOffset = newOffset }) (onOffset (_sumOffset w))

-- | Hash a weighted sum by folding the offset into the salt (using the
-- semiring-specific coefficient hash) and then mixing in the term-map hash.
instance OrdF f => Hashable (WeightedSum f sr) where
  hashWithSalt salt w = hashWithSalt offsetSalt (sumMapHash w)
    where
      offsetSalt = SR.sr_hashWithSalt (sumRepr w) salt (_sumOffset w)

-- | Hash a product via the hash of its product map.
instance OrdF f => Hashable (SemiRingProduct f sr) where
  hashWithSalt salt = hashWithSalt salt . prodMapHash

-- | Return the offset of a weighted sum when it has no variable terms,
-- and 'Nothing' otherwise.
asConstant :: WeightedSum f sr -> Maybe (SR.Coefficient sr)
asConstant w =
  if AM.null (_sumMap w)
    then Just (_sumOffset w)
    else Nothing

-- | Return true when a weighted sum has no variable terms and its offset
-- compares equal to the semiring's zero.
isZero :: SR.SemiRingRepr sr -> WeightedSum f sr -> Bool
isZero sr s = maybe False isZeroCoeff (asConstant s)
  where
    isZeroCoeff c = SR.sr_compare sr (SR.zero sr) c == EQ

-- | Attempt to view a weighted sum as one scaled term plus an offset.
--   @asAffineVar w = Just (c,r,o)@ when @denotation(w) = c*r + o@.
--   Returns 'Nothing' unless the term map has exactly one entry.
asAffineVar :: WeightedSum f sr -> Maybe (SR.Coefficient sr, f (SR.SemiRingBase sr), SR.Coefficient sr)
asAffineVar w =
  case AM.toList (_sumMap w) of
    [(WrapF r, c)] -> Just (c, r, _sumOffset w)
    _              -> Nothing

-- | Attempt to view a weighted sum as a single scaled term with no offset.
--   @asWeightedVar w = Just (c,r)@ when @denotation(w) = c*r@.
asWeightedVar :: WeightedSum f sr -> Maybe (SR.Coefficient sr, f (SR.SemiRingBase sr))
asWeightedVar w =
  case AM.toList (_sumMap w) of
    [(WrapF r, c)] | SR.eq sr (SR.zero sr) (_sumOffset w) -> Just (c, r)
    _ -> Nothing
  where
    sr = sumRepr w

-- | Attempt to view a weighted sum as a bare term: exactly one term, with
-- coefficient one and a zero offset.
--   @asVar w = Just r@ when @denotation(w) = r@
asVar :: WeightedSum f sr -> Maybe (f (SR.SemiRingBase sr))
asVar w =
  case asWeightedVar w of
    Just (c, r) | SR.eq sr (SR.one sr) c -> Just r
    _ -> Nothing
  where
    sr = sumRepr w

-- | Create a sum with no variable terms whose offset is the given constant.
constant :: Tm f => SR.SemiRingRepr sr -> SR.Coefficient sr -> WeightedSum f sr
constant sr = unfilteredSum sr AM.empty

-- | Traverse the expressions in a weighted sum, keeping every coefficient
-- and the offset.  The result is rebuilt with 'fromTerms', so the new sum
-- goes through that function's term-building logic.
traverseVars :: forall k j m sr.
  (Applicative m, Tm k) =>
  (j (SR.SemiRingBase sr) -> m (k (SR.SemiRingBase sr))) ->
  WeightedSum j sr ->
  m (WeightedSum k sr)
traverseVars onVar w =
  rebuild <$> traverse visit (toListSumMap (_sumMap w))
  where
    sr = sumRepr w
    rebuild terms = fromTerms sr terms (_sumOffset w)
    visit (t, c) = (\t' -> (t', c)) <$> onVar t

-- | Traverse the coefficients of a weighted sum (every term coefficient and
-- the offset).  Terms whose new coefficient is zero are dropped, and the
-- note of every surviving term is recomputed from its new coefficient.
traverseCoeffs :: forall m f sr.
  (Applicative m, Tm f) =>
  (SR.Coefficient sr -> m (SR.Coefficient sr)) ->
  WeightedSum f sr ->
  m (WeightedSum f sr)
traverseCoeffs onCoeff w =
  unfilteredSum sr
    <$> AM.traverseMaybeWithKey visit (_sumMap w)
    <*> onCoeff (_sumOffset w)
  where
    sr = sumRepr w
    visit (WrapF t) _ c = keepNonZero t <$> onCoeff c
    keepNonZero t c
      | SR.eq sr (SR.zero sr) c = Nothing
      | otherwise               = Just (mkNote sr c t, c)

-- | Traverse the expressions in a product.
--
-- The traversal may map several distinct source variables to the same
-- target variable (e.g. a substitution @x, y |-> z@).  The occurrence counts
-- of such colliding variables are combined with 'SR.occ_add', so that
-- @x^a * y^b@ becomes @z^(a+b)@ rather than silently losing a factor.
traverseProdVars :: forall k j m sr.
  (Applicative m, Tm k) =>
  (j (SR.SemiRingBase sr) -> m (k (SR.SemiRingBase sr))) ->
  SemiRingProduct j sr ->
  m (SemiRingProduct k sr)
traverseProdVars onVar pd =
  mkProd sr . rebuild <$>
    traverse (_1 (traverseWrap onVar)) (AM.toList (_prodMap pd))
 where
  sr = prodRepr pd

  -- Rebuild the product map, combining occurrences of keys that collide.
  rebuild = foldl' insertOcc AM.empty

  -- Insert one (variable, occurrence) pair.  If the variable is already
  -- present, add the occurrence counts instead of overwriting the old entry.
  insertOcc m (WrapF t, occ) = AM.alter (Just . bump) (WrapF t) m
    where
      bump Nothing          = (mkProdNote sr occ t, occ)
      bump (Just (_, prev)) =
        let occ' = SR.occ_add sr prev occ
        in (mkProdNote sr occ' t, occ')


-- | Create the weighted sum @s*t@ with a zero offset.  A zero scale yields the
-- empty sum, so no zero coefficient is ever stored.
scaledVar :: Tm f => SR.SemiRingRepr sr -> SR.Coefficient sr -> f (SR.SemiRingBase sr) -> WeightedSum f sr
scaledVar sr s t = unfilteredSum sr terms (SR.zero sr)
  where
    terms
      | SR.eq sr (SR.zero sr) s = AM.empty
      | otherwise               = singletonSumMap sr s t

-- | Create the weighted sum @1*t + 0@ for the given variable.
var :: Tm f => SR.SemiRingRepr sr -> f (SR.SemiRingBase sr) -> WeightedSum f sr
var sr t = unfilteredSum sr (singletonSumMap sr unit t) nought
  where
    unit   = SR.one sr
    nought = SR.zero sr

-- | Add two sums.  Terms occurring in both are merged by adding their
-- coefficients; a merged term whose coefficient becomes zero is removed.
-- The offsets are added.
add ::
  Tm f =>
  SR.SemiRingRepr sr ->
  WeightedSum f sr ->
  WeightedSum f sr ->
  WeightedSum f sr
add sr x y =
  unfilteredSum sr
    (AM.unionWithKeyMaybe combine (_sumMap x) (_sumMap y))
    (SR.add sr (_sumOffset x) (_sumOffset y))
  where
    combine (WrapF k) u v =
      let total = SR.add sr u v
      in if SR.eq sr total (SR.zero sr)
           then Nothing
           else Just (mkNote sr total k, total)

-- | Create the weighted sum @1*x + 1*y + 0@ (built with 'fromTerms').
addVars ::
  Tm f =>
  SR.SemiRingRepr sr ->
  f (SR.SemiRingBase sr) ->
  f (SR.SemiRingBase sr) ->
  WeightedSum f sr
addVars sr x y = fromTerms sr terms (SR.zero sr)
  where
    unit  = SR.one sr
    terms = [(x, unit), (y, unit)]

-- | Add @1*x@ to the sum, updating only its term map (via 'insertSumMap');
-- the offset is left unchanged.
addVar ::
  Tm f =>
  SR.SemiRingRepr sr ->
  WeightedSum f sr -> f (SR.SemiRingBase sr) -> WeightedSum f sr
addVar sr wsum x =
  wsum { _sumMap = insertSumMap sr (SR.one sr) x (_sumMap wsum) }

-- | Add a constant to the sum's offset; the term map is unchanged.
addConstant :: SR.SemiRingRepr sr -> WeightedSum f sr -> SR.Coefficient sr -> WeightedSum f sr
addConstant sr x r = x { _sumOffset = SR.add sr r (_sumOffset x) }

-- | Multiply every coefficient of a weighted sum, including its constant
--   offset, by the given scalar.
--
--   Scaling by zero short-circuits to the zero constant.  Otherwise each
--   term coefficient is multiplied by the scalar and the term's note is
--   rebuilt from the new coefficient; any term whose new coefficient is
--   zero is removed from the map.
scale ::
  Tm f =>
  SR.SemiRingRepr sr ->
  SR.Coefficient sr ->
  WeightedSum f sr ->
  WeightedSum f sr
scale sr c wsum
  | SR.eq sr c zeroCoeff = constant sr zeroCoeff
  | otherwise            = unfilteredSum sr scaledTerms scaledOffset
  where
    zeroCoeff = SR.zero sr

    scaledOffset = SR.mul sr c (view sumOffset wsum)

    scaledTerms = runIdentity (AM.traverseMaybeWithKey rescale (view sumMap wsum))

    -- Scale one term.  The product is re-checked against zero because a
    -- nonzero scalar times a nonzero coefficient may still be zero in
    -- some semirings; such terms are dropped.
    rescale (WrapF t) _ x =
      let cx = SR.mul sr c x
      in pure $ if SR.eq sr zeroCoeff cx
                  then Nothing
                  else Just (mkNote sr cx t, cx)

-- | Build a weighted sum from a list of @(term, coefficient)@ pairs and a
--   constant offset.  The pairs are collected into the term map with
--   'fromListSumMap'; the offset is stored as given.
fromTerms ::
  Tm f =>
  SR.SemiRingRepr sr ->
  [(f (SR.SemiRingBase sr), SR.Coefficient sr)] ->
  SR.Coefficient sr ->
  WeightedSum f sr
fromTerms sr = unfilteredSum sr . fromListSumMap sr

-- | Rebuild a weighted sum, possibly over a different semiring, by
--   applying effectful maps to every term and coefficient.
--
--   Effects are sequenced as follows.  For each entry produced by
--   'toListSumMap', the term function runs and then the coefficient
--   function runs.  Finally the coefficient function runs on the offset.
--   The results are reassembled with 'fromTerms'.
transformSum :: (Applicative m, Tm g) =>
  SR.SemiRingRepr sr' ->
  (SR.Coefficient sr -> m (SR.Coefficient sr')) ->
  (f (SR.SemiRingBase sr) -> m (g (SR.SemiRingBase sr'))) ->
  WeightedSum f sr ->
  m (WeightedSum g sr')
transformSum sr' transCoef transTm s =
  fromTerms sr' <$> traverse transPair (toListSumMap (_sumMap s))
                <*> transCoef (_sumOffset s)
  where
    -- Map one (term, coefficient) pair: term effect first, then coefficient.
    transPair (t, x) = fmap (,) (transTm t) <*> transCoef x


-- | Evaluate a weighted sum inside a monad, given interpretations of
--   addition, scalar multiplication, and constants.
--
--   Addition is associated to the left, as in 'foldlM': the running total
--   is always the first argument of the addition function.  A zero offset
--   is omitted and the first scaled term seeds the total.  If the offset
--   is zero and there are no terms, the result is the constant function
--   applied to zero.
evalM :: Monad m =>
  (r -> r -> m r) {- ^ Addition function -} ->
  (SR.Coefficient sr -> f (SR.SemiRingBase sr) -> m r) {- ^ Scalar multiply -} ->
  (SR.Coefficient sr -> m r) {- ^ Constant evaluation -} ->
  WeightedSum f sr ->
  m r
evalM addFn smul cnst sm =
  case (SR.eq sr offset (SR.zero sr), terms) of
    (True, [])            -> cnst (SR.zero sr)
    (True, (e, s) : rest) -> smul s e >>= accumulate rest
    (False, _)            -> cnst offset >>= accumulate terms
  where
    sr     = sumRepr sm
    offset = _sumOffset sm
    terms  = toListSumMap (_sumMap sm)

    -- Fold the remaining terms into the running total, left to right:
    -- scale the term, then add it onto the total.
    accumulate [] acc = return acc
    accumulate ((e, s) : rest) acc = do
      term <- smul s e
      acc' <- addFn acc term
      accumulate rest acc'

-- | Evaluate a sum given interpretations of addition, scalar multiplication,
--   and constants.
--
--   Terms are combined in 'toListSumMap' order.  Each new scaled term is
--   passed as the /first/ argument of the addition function, and the
--   running total as the /second/.  A zero offset is omitted and the first
--   scaled term seeds the total.  If the offset is zero and there are no
--   terms, the result is the constant function applied to zero.
--
--   The running total is forced to weak head normal form at each step, so
--   evaluating a large sum does not build a chain of unevaluated addition
--   applications.
eval ::
  (r -> r -> r) {- ^ Addition function -} ->
  (SR.Coefficient sr -> f (SR.SemiRingBase sr) -> r) {- ^ Scalar multiply -} ->
  (SR.Coefficient sr -> r) {- ^ Constant evaluation -} ->
  WeightedSum f sr ->
  r
eval addFn smul cnst w =
  case (SR.eq sr offset (SR.zero sr), terms) of
    (True, [])            -> cnst (SR.zero sr)
    (True, (e, s) : rest) -> foldl' step (smul s e) rest
    (False, _)            -> foldl' step (cnst offset) terms
  where
    sr     = sumRepr w
    offset = _sumOffset w
    terms  = toListSumMap (_sumMap w)

    -- New scaled term first, running total second (the same argument order
    -- as the original lazy loop).  'foldl'' forces the accumulator at every
    -- step instead of threading an ever-growing thunk.
    step acc (e, s) = addFn (smul s e) acc

{-# INLINABLE eval #-}


-- | Reduce a weighted sum of integers modulo a concrete integer.
--
--   Every term coefficient is replaced by its remainder (Haskell 'mod')
--   modulo @k@.  Terms whose remainder is 0 are removed, and surviving
--   terms get notes rebuilt from the reduced coefficient.  The offset is
--   reduced the same way.
reduceIntSumMod ::
  Tm f =>
  WeightedSum f SR.SemiRingInteger {- ^ The sum to reduce -} ->
  Integer {- ^ The modulus, must not be 0 -} ->
  WeightedSum f SR.SemiRingInteger
reduceIntSumMod ws k =
  unfilteredSum SR.SemiRingIntegerRepr reducedTerms (view sumOffset ws `mod` k)
  where
    sr = sumRepr ws

    reducedTerms = runIdentity (AM.traverseMaybeWithKey reduceTerm (view sumMap ws))

    -- Reduce one coefficient; drop the term if it becomes congruent to 0.
    reduceTerm (WrapF t) _ x =
      pure $ case x `mod` k of
        0  -> Nothing
        x' -> Just (mkNote sr x' t, x')

{-# INLINABLE extractCommon #-}

-- | Given two weighted sums @x@ and @y@, return a triple @(z, x', y')@
--   such that @x = z + x'@ and @y = z + y'@, where @z@ holds the parts
--   shared by @x@ and @y@.
--
--   A term is considered shared only when it occurs with the same
--   coefficient in both sums.  The offset is shared only when both offsets
--   are equal; otherwise @z@ gets a zero offset and each remainder keeps
--   its own.
--
--   This is primarily used to simplify if-then-else expressions so that
--   shared subterms are preserved.
extractCommon ::
  Tm f =>
  WeightedSum f sr ->
  WeightedSum f sr ->
  (WeightedSum f sr, WeightedSum f sr, WeightedSum f sr)
extractCommon (WeightedSum xm xc sr) (WeightedSum ym yc _) =
  ( unfilteredSum sr common commonOffset
  , unfilteredSum sr (xm `AM.difference` common) xOffset
  , unfilteredSum sr (ym `AM.difference` common) yOffset
  )
  where
    zeroCoeff = SR.zero sr

    -- Keep a key present in both maps only if its coefficients agree.
    keepIfEqual (WrapF t) (_, xv) (_, yv) =
      if SR.eq sr xv yv
        then Just (mkNote sr xv t, xv)
        else Nothing

    -- Keys present in only one of the maps are never shared.
    common = AM.mergeWithKey keepIfEqual (const AM.empty) (const AM.empty) xm ym

    (commonOffset, xOffset, yOffset) =
      if SR.eq sr xc yc
        then (xc, zeroCoeff, zeroCoeff)
        else (zeroCoeff, xc, yc)


-- | Test whether a product is trivial, i.e. its underlying term map has
--   no entries.
nullProd :: SemiRingProduct f sr -> Bool
nullProd = AM.null . _prodMap

-- | If the product is exactly one term occurring once (a single entry in
--   the term map whose occurrence count is 1), return that term.
--   Otherwise return 'Nothing'; for example, a single term squared gives
--   'Nothing'.
asProdVar :: SemiRingProduct f sr -> Maybe (f (SR.SemiRingBase sr))
asProdVar pd =
  case AM.toList (_prodMap pd) of
    [(WrapF x, occ)] | SR.occ_count (prodRepr pd) occ == 1 -> Just x
    _ -> Nothing

-- | Abstract value of a product.
--
--   This is the abstract value cached in the term map's annotation.  When
--   the map carries no annotation, it is the abstract value of the scalar
--   one (the value of an empty product).
prodAbsValue :: OrdF f => SemiRingProduct f sr -> AD.AbstractValue (SR.SemiRingBase sr)
prodAbsValue pd =
  fromSRAbsValue (maybe unitValue noteValue (AM.annotation (_prodMap pd)))
  where
    sr = prodRepr pd
    unitValue = abstractScalar sr (SR.one sr)
    noteValue (ProdNote _ v) = v

-- | Returns true if the given term appears as a key in the product's term
--   map, i.e. the product contains at least one occurrence of it (assuming
--   the 'mkProd' precondition that stored occurrences are non-zero).
prodContains :: OrdF f => SemiRingProduct f sr -> f (SR.SemiRingBase sr) -> Bool
prodContains pd x =
  case AM.lookup (WrapF x) (_prodMap pd) of
    Just _  -> True
    Nothing -> False

-- | Wrap a raw map of terms to occurrences as a product over the given
--   semiring.  The map is stored as-is.
--
--   PRECONDITION: the occurrence value for each term should be non-zero
--   (this is not checked here).
mkProd :: HashableF f => SR.SemiRingRepr sr -> ProdMap f sr -> SemiRingProduct f sr
mkProd = flip SemiRingProduct

-- | The product consisting of a single occurrence of the given term.
prodVar :: Tm f => SR.SemiRingRepr sr -> f (SR.SemiRingBase sr) -> SemiRingProduct f sr
prodVar sr = mkProd sr . singletonProdMap sr (SR.occ_one sr)

-- | Multiply two products.
--
--   Terms that appear in only one factor are kept unchanged.  Terms that
--   appear in both factors get the sum of their occurrence counts, with a
--   note rebuilt from that total.  The semiring representation is taken
--   from the first argument.
prodMul :: Tm f => SemiRingProduct f sr -> SemiRingProduct f sr -> SemiRingProduct f sr
prodMul x y =
  mkProd sr (AM.mergeWithKey addOccurrences id id (_prodMap x) (_prodMap y))
  where
    sr = prodRepr x

    -- A term present in both factors: add occurrences and rebuild the note.
    addOccurrences (WrapF t) (_, a) (_, b) =
      let total = SR.occ_add sr a b
      in Just (mkProdNote sr total t, total)

-- | Evaluate a product purely, given a function representing
--   multiplication and a function to evaluate terms.
--
--   This is 'prodEvalM' specialised to the 'Identity' monad.  The result
--   is 'Nothing' when 'prodEvalM' finds no term with a nonzero occurrence
--   count.
prodEval ::
  (r -> r -> r) {- ^ multiplication evaluation -} ->
  (f (SR.SemiRingBase sr) -> r) {- ^ term evaluation -} ->
  SemiRingProduct f sr ->
  Maybe r
prodEval mul tm = runIdentity . prodEvalM mulI (Identity . tm)
  where
    mulI a b = Identity (mul a b)

-- | Evaluate a product, given a function representing multiplication
--   and a function to evaluate terms, where both functions are threaded
--   through a monad.
--
--   Terms are visited in the order produced by 'AM.toList' on the
--   underlying product map.
--
--   * Entries whose occurrence count is zero are skipped.
--   * The first contributing term @t@ with count @n@ seeds the running
--     product as @t^n@, using @n-1@ multiplications.
--   * Every later term @t@ with count @n@ is folded in by @n@ successive
--     calls @mul z t@, where @z@ is the running product.
--
--   The result is 'Nothing' when no term has a nonzero occurrence count.
prodEvalM :: Monad m =>
  (r -> r -> m r) {-^ multiplication evaluation -} ->
  (f (SR.SemiRingBase sr) -> m r) {-^ term evaluation -} ->
  SemiRingProduct f sr ->
  m (Maybe r)
prodEvalM mul tm om = loop Nothing (AM.toList (_prodMap om))
  where
    sr = prodRepr om

    -- Walk the (term, occurrence) pairs, carrying the partial product.
    -- 'Nothing' means no term with a nonzero count has been seen yet.
    loop acc [] = return acc
    loop acc ((WrapF x, occ) : rest)
      | cnt == 0 = loop acc rest
      | otherwise = do
          t <- tm x
          acc' <- case acc of
                    -- First contributing term: start from t itself,
                    -- so only cnt-1 further multiplications are needed.
                    Nothing -> power (cnt - 1) t t
                    -- Otherwise multiply t into the existing product
                    -- cnt times.
                    Just z  -> power cnt t z
          loop (Just acc') rest
      where
        cnt = SR.occ_count sr occ

    -- power k t z computes z * t^k with k successive calls to 'mul'.
    power k t z
      | k > 0 = mul z t >>= power (k - 1) t
      | otherwise = return z