{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wwarn #-}
module What4.Expr.WeightedSum
(
Tm
, WeightedSum
, sumRepr
, sumOffset
, sumAbsValue
, constant
, var
, scaledVar
, asConstant
, asVar
, asWeightedVar
, asAffineVar
, isZero
, traverseVars
, traverseCoeffs
, add
, addVar
, addVars
, addConstant
, scale
, eval
, evalM
, extractCommon
, fromTerms
, transformSum
, reduceIntSumMod
, SemiRingProduct
, traverseProdVars
, nullProd
, asProdVar
, prodRepr
, prodVar
, prodAbsValue
, prodMul
, prodEval
, prodEvalM
, prodContains
) where
import Control.Lens
import Control.Monad.State
import qualified Data.BitVector.Sized as BV
import Data.Hashable
import Data.Kind
import Data.List (foldl')
import Data.Maybe
import Data.Parameterized.Classes
import What4.BaseTypes
import qualified What4.SemiRing as SR
import What4.Utils.AnnotatedMap (AnnotatedMap)
import qualified What4.Utils.AnnotatedMap as AM
import qualified What4.Utils.AbstractDomains as AD
import qualified What4.Utils.BVDomain.Arith as A
import qualified What4.Utils.BVDomain.XOR as X
import qualified What4.Utils.BVDomain as BVD
import What4.Utils.IncrHash
-- | Abstract-domain value attached to sums and products over semiring @sr@.
--   There is one constructor per semiring; bitvectors use the arithmetic
--   domain for the 'SR.BVArith' flavour and the XOR domain for 'SR.BVBits'.
data SRAbsValue :: SR.SemiRing -> Type where
  -- | Natural-number range.
  SRAbsNatAdd :: !AD.NatValueRange -> SRAbsValue SR.SemiRingNat
  -- | Integer range.
  SRAbsIntAdd :: !(AD.ValueRange Integer) -> SRAbsValue SR.SemiRingInteger
  -- | Real-valued abstract value.
  SRAbsRealAdd :: !AD.RealAbstractValue -> SRAbsValue SR.SemiRingReal
  -- | Arithmetic bitvector domain, used for the 'SR.BVArith' flavour.
  SRAbsBVAdd :: (1 <= w) => !(A.Domain w) -> SRAbsValue (SR.SemiRingBV SR.BVArith w)
  -- | XOR bitvector domain, used for the 'SR.BVBits' flavour.
  SRAbsBVXor :: (1 <= w) => !(X.Domain w) -> SRAbsValue (SR.SemiRingBV SR.BVBits w)
-- | Combine abstract values with the semiring's addition: the domain's add
--   operation ('AD.natRangeAdd', 'AD.addRange', 'AD.ravAdd', 'A.add') or,
--   for the bitwise bitvector semiring, 'X.xor'.
instance Semigroup (SRAbsValue sr) where
  (<>) (SRAbsNatAdd a) (SRAbsNatAdd b) = SRAbsNatAdd (AD.natRangeAdd a b)
  (<>) (SRAbsIntAdd a) (SRAbsIntAdd b) = SRAbsIntAdd (AD.addRange a b)
  (<>) (SRAbsRealAdd a) (SRAbsRealAdd b) = SRAbsRealAdd (AD.ravAdd a b)
  (<>) (SRAbsBVAdd a) (SRAbsBVAdd b) = SRAbsBVAdd (A.add a b)
  (<>) (SRAbsBVXor a) (SRAbsBVXor b) = SRAbsBVXor (X.xor a b)
-- | Combine two abstract values with the semiring's multiplication
--   ('AD.natRangeMul', 'AD.mulRange', 'AD.ravMul', 'A.mul'); in the bitwise
--   bitvector semiring the product is bitwise AND ('X.and').
(.**) :: SRAbsValue sr -> SRAbsValue sr -> SRAbsValue sr
(.**) (SRAbsNatAdd a) (SRAbsNatAdd b) = SRAbsNatAdd (AD.natRangeMul a b)
(.**) (SRAbsIntAdd a) (SRAbsIntAdd b) = SRAbsIntAdd (AD.mulRange a b)
(.**) (SRAbsRealAdd a) (SRAbsRealAdd b) = SRAbsRealAdd (AD.ravMul a b)
(.**) (SRAbsBVAdd a) (SRAbsBVAdd b) = SRAbsBVAdd (A.mul a b)
(.**) (SRAbsBVXor a) (SRAbsBVXor b) = SRAbsBVXor (X.and a b)
-- | Abstract value of the single weighted term @c * e@ in semiring @sr@:
--   the term's abstract value scaled by the coefficient.  Arithmetic
--   bitvectors scale by the signed reading of @c@; bitwise bitvectors
--   AND with the unsigned reading of @c@.
abstractTerm ::
  AD.HasAbsValue f =>
  SR.SemiRingRepr sr -> SR.Coefficient sr -> f (SR.SemiRingBase sr) -> SRAbsValue sr
abstractTerm SR.SemiRingNatRepr c e =
  SRAbsNatAdd (AD.natRangeScalarMul c (AD.getAbsValue e))
abstractTerm SR.SemiRingIntegerRepr c e =
  SRAbsIntAdd (AD.rangeScalarMul c (AD.getAbsValue e))
abstractTerm SR.SemiRingRealRepr c e =
  SRAbsRealAdd (AD.ravScalarMul c (AD.getAbsValue e))
abstractTerm (SR.SemiRingBVRepr SR.BVArithRepr w) c e =
  SRAbsBVAdd (A.scale (BV.asSigned w c) (BVD.asArithDomain (AD.getAbsValue e)))
abstractTerm (SR.SemiRingBVRepr SR.BVBitsRepr _) c e =
  SRAbsBVXor (X.and_scalar (BV.asUnsigned c) (BVD.asXorDomain (AD.getAbsValue e)))
-- | Lift a term's own abstract value into the 'SRAbsValue' for semiring
--   @sr@, converting bitvector domains to the flavour's representation.
abstractVal :: AD.HasAbsValue f => SR.SemiRingRepr sr -> f (SR.SemiRingBase sr) -> SRAbsValue sr
abstractVal SR.SemiRingNatRepr e = SRAbsNatAdd (AD.getAbsValue e)
abstractVal SR.SemiRingIntegerRepr e = SRAbsIntAdd (AD.getAbsValue e)
abstractVal SR.SemiRingRealRepr e = SRAbsRealAdd (AD.getAbsValue e)
abstractVal (SR.SemiRingBVRepr SR.BVArithRepr _) e =
  SRAbsBVAdd (BVD.asArithDomain (AD.getAbsValue e))
abstractVal (SR.SemiRingBVRepr SR.BVBitsRepr _) e =
  SRAbsBVXor (BVD.asXorDomain (AD.getAbsValue e))
-- | Abstract value denoting exactly the constant @c@ in semiring @sr@.
--   Bitvector constants are injected via their unsigned value.
abstractScalar ::
  SR.SemiRingRepr sr -> SR.Coefficient sr -> SRAbsValue sr
abstractScalar SR.SemiRingNatRepr c = SRAbsNatAdd (AD.natSingleRange c)
abstractScalar SR.SemiRingIntegerRepr c = SRAbsIntAdd (AD.SingleRange c)
abstractScalar SR.SemiRingRealRepr c = SRAbsRealAdd (AD.ravSingle c)
abstractScalar (SR.SemiRingBVRepr SR.BVArithRepr w) c =
  SRAbsBVAdd (A.singleton w (BV.asUnsigned c))
abstractScalar (SR.SemiRingBVRepr SR.BVBitsRepr w) c =
  SRAbsBVXor (X.singleton w (BV.asUnsigned c))
-- | Convert a semiring abstract value back to the generic 'AD.AbstractValue'
--   of the semiring's base type.
fromSRAbsValue ::
  SRAbsValue sr -> AD.AbstractValue (SR.SemiRingBase sr)
fromSRAbsValue (SRAbsNatAdd r) = r
fromSRAbsValue (SRAbsIntAdd r) = r
fromSRAbsValue (SRAbsRealAdd r) = r
fromSRAbsValue (SRAbsBVAdd d) = BVD.BVDArith d
fromSRAbsValue (SRAbsBVXor d) = BVD.fromXorDomain d
-- | Constraints required of a term type @f@ used in sums and products:
--   parametric hashing, parametric ordering, and access to abstract values.
type Tm f = (HashableF f, OrdF f, AD.HasAbsValue f)

-- | Wraps a term whose type is the base type of semiring @i@, so that it can
--   be given ordinary 'Ord', 'Eq' and 'Hashable' instances and used as a map key.
newtype WrapF (f :: BaseType -> Type) (i :: SR.SemiRing) = WrapF (f (SR.SemiRingBase i))
-- | Order wrapped terms by the underlying parametric ordering 'compareF'.
instance OrdF f => Ord (WrapF f i) where
  compare (WrapF a) (WrapF b) = toOrdering (compareF a b)

-- | Two wrapped terms are equal when 'testEquality' succeeds on them.
instance TestEquality f => Eq (WrapF f i) where
  WrapF a == WrapF b =
    case testEquality a b of
      Just _ -> True
      Nothing -> False

-- | Hash a wrapped term with the underlying parametric hash.
instance HashableF f => Hashable (WrapF f i) where
  hashWithSalt salt (WrapF a) = hashWithSaltF salt a
-- | Apply an effectful term transformation underneath the 'WrapF' wrapper.
traverseWrap :: Functor m => (f (SR.SemiRingBase i) -> m (g (SR.SemiRingBase i))) -> WrapF f i -> m (WrapF g i)
traverseWrap f (WrapF inner) = fmap WrapF (f inner)
-- | Annotation for entries of a sum map: an incremental hash and an abstract
--   value.  Notes combine by merging hashes and adding abstract values.
data Note sr = Note !IncrHash !(SRAbsValue sr)

instance Semigroup (Note sr) where
  (<>) (Note hashA absA) (Note hashB absB) =
    Note (hashA <> hashB) (absA <> absB)

-- | Annotation for entries of a product map: an incremental hash and an
--   abstract value.  Notes combine by merging hashes and multiplying
--   abstract values with '.**'.
data ProdNote sr = ProdNote !IncrHash !(SRAbsValue sr)

instance Semigroup (ProdNote sr) where
  (<>) (ProdNote hashA absA) (ProdNote hashB absB) =
    ProdNote (hashA <> hashB) (absA .** absB)
-- | Build the sum-map annotation for the weighted term @c * t@: a hash of
--   the term and coefficient, and the term's scaled abstract value.
mkNote ::
  (HashableF f, AD.HasAbsValue f) =>
  SR.SemiRingRepr sr -> SR.Coefficient sr -> f (SR.SemiRingBase sr) -> Note sr
mkNote sr c t =
  let termHash = SR.sr_hashWithSalt sr (hashF t) c
      termAbs = abstractTerm sr c t
  in Note (mkIncrHash termHash) termAbs
-- | Build the product-map annotation for @t@ raised to the occurrence
--   count of @occ@: a hash of the term and occurrence, and the abstract
--   value of @t^n@.  The abstract power is computed by repeated '.**'.
--
--   A count of zero denotes @t^0 = 1@, so its abstract value is that of the
--   semiring's 'SR.one'.  This matches what 'prodAbsValue' uses for the
--   empty product, rather than wrongly reusing @t@'s own value.
mkProdNote ::
  (HashableF f, AD.HasAbsValue f) =>
  SR.SemiRingRepr sr ->
  SR.Occurrence sr ->
  f (SR.SemiRingBase sr) ->
  ProdNote sr
mkProdNote sr occ t = ProdNote (mkIncrHash h) d
  where
    h = SR.occ_hashWithSalt sr (hashF t) occ
    v = abstractVal sr t
    power :: Integer
    power = fromIntegral (SR.occ_count sr occ)
    d | power <= 0 = abstractScalar sr (SR.one sr)
      | otherwise = go (power - 1) v
    -- Multiply @v@ into the accumulator @n@ more times.
    go (n :: Integer) x
      | n > 0 = go (n - 1) (v .** x)
      | otherwise = x
-- | Map from terms to their coefficients, annotated with 'Note's.
type SumMap f sr = AnnotatedMap (WrapF f sr) (Note sr) (SR.Coefficient sr)
-- | Map from terms to their occurrence counts, annotated with 'ProdNote's.
type ProdMap f sr = AnnotatedMap (WrapF f sr) (ProdNote sr) (SR.Occurrence sr)
-- | Add the weighted term @c * t@ to a sum map.
--
--   If @t@ is already present, its coefficient is increased by @c@.  An
--   entry whose resulting coefficient is zero is removed.  Likewise, a zero
--   @c@ for an absent term creates no entry, so the map never gains
--   zero-weighted terms, matching 'add', 'scale' and 'traverseCoeffs'.
insertSumMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  SR.Coefficient sr -> f (SR.SemiRingBase sr) -> SumMap f sr -> SumMap f sr
insertSumMap sr c t = AM.alter update (WrapF t)
  where
    -- New coefficient: @c@ for a fresh term, or the old coefficient plus @c@.
    update existing = entryFor (maybe c (\(_, c0) -> SR.add sr c0 c) existing)
    entryFor c'
      | SR.eq sr (SR.zero sr) c' = Nothing
      | otherwise = Just (mkNote sr c' t, c')
-- | A sum map holding only the term @c * t@ (no zero check is performed).
singletonSumMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  SR.Coefficient sr -> f (SR.SemiRingBase sr) -> SumMap f sr
singletonSumMap sr c t = AM.singleton key note c
  where
    key = WrapF t
    note = mkNote sr c t
-- | A product map holding only @t@ with occurrence @occ@.
singletonProdMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  SR.Occurrence sr ->
  f (SR.SemiRingBase sr) ->
  ProdMap f sr
singletonProdMap sr occ t = AM.singleton key note occ
  where
    key = WrapF t
    note = mkProdNote sr occ t
-- | Build a sum map from @(term, coefficient)@ pairs.  Each pair is added
--   with 'insertSumMap', so repeated terms have their coefficients summed.
fromListSumMap ::
  Tm f =>
  SR.SemiRingRepr sr ->
  [(f (SR.SemiRingBase sr), SR.Coefficient sr)] -> SumMap f sr
fromListSumMap sr = foldr (\(t, c) acc -> insertSumMap sr c t acc) AM.empty
-- | List the @(term, coefficient)@ entries of a sum map, unwrapping each key.
toListSumMap :: SumMap f sr -> [(f (SR.SemiRingBase sr), SR.Coefficient sr)]
toListSumMap = map (\(WrapF t, c) -> (t, c)) . AM.toList
-- | A weighted sum @offset + c1*t1 + ... + cn*tn@ over semiring @sr@.
--   Most operations in this module drop terms whose coefficient becomes
--   zero.
data WeightedSum (f :: BaseType -> Type) (sr :: SR.SemiRing)
  = WeightedSum { _sumMap :: !(SumMap f sr)
                  -- ^ term coefficients
                , _sumOffset :: !(SR.Coefficient sr)
                  -- ^ constant offset
                , sumRepr :: !(SR.SemiRingRepr sr)
                  -- ^ runtime representative of the semiring
                }

-- | A product @t1^n1 * ... * tk^nk@ over semiring @sr@, where each @ni@ is
--   the term's occurrence count.
data SemiRingProduct (f :: BaseType -> Type) (sr :: SR.SemiRing)
  = SemiRingProduct { _prodMap :: !(ProdMap f sr)
                      -- ^ term occurrences
                    , prodRepr :: !(SR.SemiRingRepr sr)
                      -- ^ runtime representative of the semiring
                    }
-- | Cached hash of all terms in the sum (excluding the offset); 'mempty'
--   when the map carries no annotation.
sumMapHash :: OrdF f => WeightedSum f sr -> IncrHash
sumMapHash x = maybe mempty noteHash (AM.annotation (_sumMap x))
  where
    noteHash (Note h _) = h
-- | Cached hash of all factors in the product; 'mempty' when the map carries
--   no annotation.
prodMapHash :: OrdF f => SemiRingProduct f sr -> IncrHash
prodMapHash pd = maybe mempty noteHash (AM.annotation (_prodMap pd))
  where
    noteHash (ProdNote h _) = h
-- | Abstract value of a whole sum: the offset's value added to the cached
--   abstract value of the terms (if any).
sumAbsValue :: OrdF f => WeightedSum f sr -> AD.AbstractValue (SR.SemiRingBase sr)
sumAbsValue wsum =
  fromSRAbsValue (maybe offsetAbs withTerms (AM.annotation (_sumMap wsum)))
  where
    offsetAbs = abstractScalar (sumRepr wsum) (_sumOffset wsum)
    withTerms (Note _ termsAbs) = offsetAbs <> termsAbs
-- | Products are equal when their cached hashes agree, their semirings
--   coincide, and their occurrence maps are equal under 'SR.occ_eq'.
--   The hash comparison is a cheap early rejection.
instance OrdF f => TestEquality (SemiRingProduct f) where
  testEquality x y
    | prodMapHash x /= prodMapHash y = Nothing
    | otherwise =
        case testEquality (prodRepr x) (prodRepr y) of
          Just Refl
            | AM.eqBy (SR.occ_eq (prodRepr x)) (_prodMap x) (_prodMap y) -> Just Refl
          _ -> Nothing
-- | Sums are equal when their cached hashes agree, their semirings
--   coincide, and both offsets and coefficient maps are equal under 'SR.eq'.
--   The hash comparison is a cheap early rejection.
instance OrdF f => TestEquality (WeightedSum f) where
  testEquality x y
    | sumMapHash x /= sumMapHash y = Nothing
    | otherwise =
        case testEquality (sumRepr x) (sumRepr y) of
          Just Refl
            | SR.eq (sumRepr x) (_sumOffset x) (_sumOffset y)
            , AM.eqBy (SR.eq (sumRepr x)) (_sumMap x) (_sumMap y) -> Just Refl
          _ -> Nothing
-- | Assemble a 'WeightedSum' from its parts without normalising the map.
unfilteredSum ::
  SR.SemiRingRepr sr ->
  SumMap f sr ->
  SR.Coefficient sr ->
  WeightedSum f sr
unfilteredSum sr m c =
  WeightedSum { _sumMap = m, _sumOffset = c, sumRepr = sr }
-- | Lens onto the term map of a sum.
sumMap :: HashableF f => Lens' (WeightedSum f sr) (SumMap f sr)
sumMap k w = fmap (\m -> w { _sumMap = m }) (k (_sumMap w))
-- | Lens onto the constant offset of a sum.
sumOffset :: Lens' (WeightedSum f sr) (SR.Coefficient sr)
sumOffset k s = fmap (\v -> s { _sumOffset = v }) (k (_sumOffset s))
-- | Hash the offset (with the semiring's own coefficient hash), then mix in
--   the cached hash of the term map.
instance OrdF f => Hashable (WeightedSum f sr) where
  hashWithSalt salt w =
    let salted = SR.sr_hashWithSalt (sumRepr w) salt (_sumOffset w)
    in hashWithSalt salted (sumMapHash w)

-- | Hash only the cached hash of the occurrence map.
instance OrdF f => Hashable (SemiRingProduct f sr) where
  hashWithSalt salt = hashWithSalt salt . prodMapHash
-- | The offset of a sum that has no terms; 'Nothing' if any term is present.
asConstant :: WeightedSum f sr -> Maybe (SR.Coefficient sr)
asConstant w =
  if AM.null (_sumMap w)
    then Just (_sumOffset w)
    else Nothing
-- | True exactly when the sum has no terms and its offset equals zero.
--   Coefficient equality uses 'SR.eq', like the rest of this module.
isZero :: SR.SemiRingRepr sr -> WeightedSum f sr -> Bool
isZero sr s = maybe False (SR.eq sr (SR.zero sr)) (asConstant s)
-- | Decompose a single-term sum @c * r + o@ into @(c, r, o)@.
asAffineVar :: WeightedSum f sr -> Maybe (SR.Coefficient sr, f (SR.SemiRingBase sr), SR.Coefficient sr)
asAffineVar w =
  case AM.toList (_sumMap w) of
    [(WrapF r, c)] -> Just (c, r, _sumOffset w)
    _ -> Nothing
-- | Decompose a single-term sum with zero offset, @c * r@, into @(c, r)@.
asWeightedVar :: WeightedSum f sr -> Maybe (SR.Coefficient sr, f (SR.SemiRingBase sr))
asWeightedVar w =
  case AM.toList (_sumMap w) of
    [(WrapF r, c)] | SR.eq sr (SR.zero sr) (_sumOffset w) -> Just (c, r)
    _ -> Nothing
  where
    sr = sumRepr w
-- | Recognise a sum that is exactly one term with coefficient one and zero
--   offset, returning that term.
asVar :: WeightedSum f sr -> Maybe (f (SR.SemiRingBase sr))
asVar w =
  case AM.toList (_sumMap w) of
    [(WrapF r, c)]
      | SR.eq sr (SR.one sr) c && SR.eq sr (SR.zero sr) (_sumOffset w) -> Just r
    _ -> Nothing
  where
    sr = sumRepr w
-- | A sum with no terms whose value is the given constant.
constant :: Tm f => SR.SemiRingRepr sr -> SR.Coefficient sr -> WeightedSum f sr
constant sr = unfilteredSum sr AM.empty
-- | Effectfully replace every term of a sum, keeping coefficients and the
--   offset.  The sum is rebuilt with 'fromTerms', so terms that map to the
--   same result have their coefficients combined.
traverseVars :: forall k j m sr.
  (Applicative m, Tm k) =>
  (j (SR.SemiRingBase sr) -> m (k (SR.SemiRingBase sr))) ->
  WeightedSum j sr ->
  m (WeightedSum k sr)
traverseVars f w =
  rebuild <$> traverse visit (toListSumMap (_sumMap w))
  where
    visit (t, c) = (\t' -> (t', c)) <$> f t
    rebuild terms = fromTerms (sumRepr w) terms (_sumOffset w)
-- | Effectfully replace every coefficient and the offset of a sum.  Terms
--   whose new coefficient is zero are dropped.  The term coefficients are
--   visited before the offset.
traverseCoeffs :: forall m f sr.
  (Applicative m, Tm f) =>
  (SR.Coefficient sr -> m (SR.Coefficient sr)) ->
  WeightedSum f sr ->
  m (WeightedSum f sr)
traverseCoeffs f w =
  pure (unfilteredSum sr)
    <*> AM.traverseMaybeWithKey visit (_sumMap w)
    <*> f (_sumOffset w)
  where
    sr = sumRepr w
    visit (WrapF t) _oldNote c = keepNonZero t <$> f c
    keepNonZero t c
      | SR.eq sr (SR.zero sr) c = Nothing
      | otherwise = Just (mkNote sr c t, c)
-- | Effectfully replace every factor of a product, keeping occurrence counts.
--
--   If two distinct factors are mapped to the same term, their occurrences
--   are combined with 'SR.occ_add' (so @x * y@ with @y := x@ becomes @x^2@).
--   Previously a plain insert silently kept only one of the counts.
traverseProdVars :: forall k j m sr.
  (Applicative m, Tm k) =>
  (j (SR.SemiRingBase sr) -> m (k (SR.SemiRingBase sr))) ->
  SemiRingProduct j sr ->
  m (SemiRingProduct k sr)
traverseProdVars f pd =
  mkProd sr . rebuild <$>
    traverse (_1 (traverseWrap f)) (AM.toList (_prodMap pd))
  where
    sr = prodRepr pd
    rebuild = foldl' insertOcc AM.empty
    -- Add @occ@ occurrences of @t@, merging with any existing entry.
    insertOcc m (WrapF t, occ) = AM.alter (Just . combine) (WrapF t) m
      where
        combine Nothing = (mkProdNote sr occ t, occ)
        combine (Just (_, occ0)) =
          let occ' = SR.occ_add sr occ0 occ
          in (mkProdNote sr occ' t, occ')
-- | The sum @s * t@; a zero scale yields the empty (zero) sum.
scaledVar :: Tm f => SR.SemiRingRepr sr -> SR.Coefficient sr -> f (SR.SemiRingBase sr) -> WeightedSum f sr
scaledVar sr s t = unfilteredSum sr terms (SR.zero sr)
  where
    terms
      | SR.eq sr (SR.zero sr) s = AM.empty
      | otherwise = singletonSumMap sr s t
-- | The sum consisting of the single term @1 * t@ with zero offset.
var :: Tm f => SR.SemiRingRepr sr -> f (SR.SemiRingBase sr) -> WeightedSum f sr
var sr t =
  WeightedSum
    { _sumMap = singletonSumMap sr (SR.one sr) t
    , _sumOffset = SR.zero sr
    , sumRepr = sr
    }
-- | Add two sums: coefficients of shared terms are added (terms summing to
--   zero are dropped) and offsets are added.
add ::
  Tm f =>
  SR.SemiRingRepr sr ->
  WeightedSum f sr ->
  WeightedSum f sr ->
  WeightedSum f sr
add sr x y = unfilteredSum sr combined offset
  where
    combined = AM.unionWithKeyMaybe combine (_sumMap x) (_sumMap y)
    offset = SR.add sr (_sumOffset x) (_sumOffset y)
    combine (WrapF t) a b =
      let s = SR.add sr a b
      in if SR.eq sr s (SR.zero sr)
           then Nothing
           else Just (mkNote sr s t, s)
-- | The sum @x + y@ (which becomes @2 * x@ when both are the same term).
addVars ::
  Tm f =>
  SR.SemiRingRepr sr ->
  f (SR.SemiRingBase sr) ->
  f (SR.SemiRingBase sr) ->
  WeightedSum f sr
addVars sr x y = fromTerms sr [ (t, SR.one sr) | t <- [x, y] ] (SR.zero sr)
-- | Add the term @1 * x@ to a sum.
addVar ::
  Tm f =>
  SR.SemiRingRepr sr ->
  WeightedSum f sr -> f (SR.SemiRingBase sr) -> WeightedSum f sr
addVar sr wsum x =
  wsum { _sumMap = insertSumMap sr (SR.one sr) x (_sumMap wsum) }
-- | Add the constant @r@ to the offset of a sum.
addConstant :: SR.SemiRingRepr sr -> WeightedSum f sr -> SR.Coefficient sr -> WeightedSum f sr
addConstant sr x r = x { _sumOffset = SR.add sr r (_sumOffset x) }
-- | Multiply every coefficient and the offset by @c@.  Scaling by zero gives
--   the zero constant; terms whose scaled coefficient is zero are dropped.
scale :: Tm f => SR.SemiRingRepr sr -> SR.Coefficient sr -> WeightedSum f sr -> WeightedSum f sr
scale sr c wsum
  | SR.eq sr c (SR.zero sr) = constant sr (SR.zero sr)
  | otherwise = unfilteredSum sr scaled (SR.mul sr c (_sumOffset wsum))
  where
    scaled =
      runIdentity (AM.traverseMaybeWithKey (\k n x -> Identity (rescale k n x)) (_sumMap wsum))
    rescale (WrapF t) _ x
      | SR.eq sr (SR.zero sr) cx = Nothing
      | otherwise = Just (mkNote sr cx t, cx)
      where
        cx = SR.mul sr c x
-- | Build a sum from @(term, coefficient)@ pairs and an offset; repeated
--   terms are combined by 'fromListSumMap'.
fromTerms ::
  Tm f =>
  SR.SemiRingRepr sr ->
  [(f (SR.SemiRingBase sr), SR.Coefficient sr)] ->
  SR.Coefficient sr ->
  WeightedSum f sr
fromTerms sr tms = unfilteredSum sr (fromListSumMap sr tms)
-- | Effectfully map a sum into another semiring @sr'@, translating each term
--   and each coefficient (terms first, then the offset), and rebuilding the
--   result with 'fromTerms'.
transformSum :: (Applicative m, Tm g) =>
  SR.SemiRingRepr sr' ->
  (SR.Coefficient sr -> m (SR.Coefficient sr')) ->
  (f (SR.SemiRingBase sr) -> m (g (SR.SemiRingBase sr'))) ->
  WeightedSum f sr ->
  m (WeightedSum g sr')
transformSum sr' transCoef transTm s =
  pure (fromTerms sr')
    <*> traverse convertTerm (toListSumMap (_sumMap s))
    <*> transCoef (_sumOffset s)
  where
    convertTerm (t, x) = pure (,) <*> transTm t <*> transCoef x
-- | Monadically fold a sum into a result.
--
--   @smul c t@ evaluates one term, @cnst@ evaluates a constant, and
--   @addFn acc next@ combines the running result with the next term.  A zero
--   offset is omitted entirely; a sum with no terms and zero offset evaluates
--   to @cnst 0@.
evalM :: Monad m =>
  (r -> r -> m r) ->
  (SR.Coefficient sr -> f (SR.SemiRingBase sr) -> m r) ->
  (SR.Coefficient sr -> m r) ->
  WeightedSum f sr ->
  m r
evalM addFn smul cnst sm
  | offsetIsZero =
      case terms of
        [] -> cnst (SR.zero sr)
        (e, s) : rest -> smul s e >>= accumulate rest
  | otherwise = cnst offset >>= accumulate terms
  where
    sr = sumRepr sm
    offset = _sumOffset sm
    offsetIsZero = SR.eq sr offset (SR.zero sr)
    terms = toListSumMap (_sumMap sm)
    accumulate [] acc = return acc
    accumulate ((e, s) : rest) acc = smul s e >>= addFn acc >>= accumulate rest
-- | Purely fold a sum into a result.
--
--   @smul c t@ evaluates one term and @cnst@ evaluates a constant.  Unlike
--   'evalM', @addFn@ receives the new term first and the running result
--   second.  A zero offset is omitted; an empty zero sum is @cnst 0@.
eval ::
  (r -> r -> r) ->
  (SR.Coefficient sr -> f (SR.SemiRingBase sr) -> r) ->
  (SR.Coefficient sr -> r) ->
  WeightedSum f sr ->
  r
eval addFn smul cnst w
  | offsetIsZero =
      case terms of
        [] -> cnst (SR.zero sr)
        (e, s) : rest -> accumulate rest (smul s e)
  | otherwise = accumulate terms (cnst offset)
  where
    sr = sumRepr w
    offset = _sumOffset w
    offsetIsZero = SR.eq sr offset (SR.zero sr)
    terms = toListSumMap (_sumMap w)
    accumulate [] acc = acc
    accumulate ((e, s) : rest) acc = accumulate rest (addFn (smul s e) acc)
{-# INLINABLE eval #-}
-- | Reduce every coefficient and the offset of an integer sum modulo @k@,
--   dropping terms whose reduced coefficient is zero.  Uses Haskell's 'mod',
--   so @k@ must be non-zero.
reduceIntSumMod ::
  Tm f =>
  WeightedSum f SR.SemiRingInteger ->
  Integer ->
  WeightedSum f SR.SemiRingInteger
reduceIntSumMod ws k =
  unfilteredSum SR.SemiRingIntegerRepr reduced (_sumOffset ws `mod` k)
  where
    sr = sumRepr ws
    reduced = runIdentity (AM.traverseMaybeWithKey step (_sumMap ws))
    step (WrapF t) _ x =
      Identity $ case x `mod` k of
        0 -> Nothing
        x' -> Just (mkNote sr x' t, x')
{-# INLINABLE extractCommon #-}
-- | Split two sums into a shared part and two remainders @(z, x', y')@.
--
--   @z@ contains every term present in both inputs with equal coefficients,
--   plus the offset when both offsets are equal.  @x'@ and @y'@ contain the
--   remaining terms and offsets, so that @x = z + x'@ and @y = z + y'@.
extractCommon ::
  Tm f =>
  WeightedSum f sr ->
  WeightedSum f sr ->
  (WeightedSum f sr, WeightedSum f sr, WeightedSum f sr)
extractCommon (WeightedSum xm xc sr) (WeightedSum ym yc _) =
  (shared, xRest, yRest)
  where
    sameCoeff (WrapF t) (_, a) (_, b)
      | SR.eq sr a b = Just (mkNote sr a t, a)
      | otherwise = Nothing
    sharedMap = AM.mergeWithKey sameCoeff (const AM.empty) (const AM.empty) xm ym
    offsetsMatch = SR.eq sr xc yc
    sharedOffset = if offsetsMatch then xc else SR.zero sr
    xOffset = if offsetsMatch then SR.zero sr else xc
    yOffset = if offsetsMatch then SR.zero sr else yc
    shared = unfilteredSum sr sharedMap sharedOffset
    xRest = unfilteredSum sr (xm `AM.difference` sharedMap) xOffset
    yRest = unfilteredSum sr (ym `AM.difference` sharedMap) yOffset
-- | True for the empty product (no factors).
nullProd :: SemiRingProduct f sr -> Bool
nullProd = AM.null . _prodMap
-- | Recognise a product consisting of a single factor occurring once.
asProdVar :: SemiRingProduct f sr -> Maybe (f (SR.SemiRingBase sr))
asProdVar pd =
  case AM.toList (_prodMap pd) of
    [(WrapF x, occ)] | SR.occ_count (prodRepr pd) occ == 1 -> Just x
    _ -> Nothing
-- | Abstract value of a product: the cached annotation, or the abstract
--   value of 'SR.one' for the empty product.
prodAbsValue :: OrdF f => SemiRingProduct f sr -> AD.AbstractValue (SR.SemiRingBase sr)
prodAbsValue pd =
  fromSRAbsValue (maybe unitAbs noteAbs (AM.annotation (_prodMap pd)))
  where
    sr = prodRepr pd
    unitAbs = abstractScalar sr (SR.one sr)
    noteAbs (ProdNote _ v) = v
-- | Does the product have @x@ as one of its factors?
prodContains :: OrdF f => SemiRingProduct f sr -> f (SR.SemiRingBase sr) -> Bool
prodContains pd x =
  case AM.lookup (WrapF x) (_prodMap pd) of
    Just _ -> True
    Nothing -> False
-- | Assemble a product from its occurrence map.
mkProd :: HashableF f => SR.SemiRingRepr sr -> ProdMap f sr -> SemiRingProduct f sr
mkProd sr m = SemiRingProduct { _prodMap = m, prodRepr = sr }
-- | The product consisting of the single factor @x@ occurring once.
prodVar :: Tm f => SR.SemiRingRepr sr -> f (SR.SemiRingBase sr) -> SemiRingProduct f sr
prodVar sr x = mkProd sr (AM.singleton (WrapF x) (mkProdNote sr once x) once)
  where
    once = SR.occ_one sr
-- | Multiply two products: factors in both have their occurrences added
--   with 'SR.occ_add'; all other factors are kept unchanged.
prodMul :: Tm f => SemiRingProduct f sr -> SemiRingProduct f sr -> SemiRingProduct f sr
prodMul x y = mkProd sr (AM.mergeWithKey combine id id (_prodMap x) (_prodMap y))
  where
    sr = prodRepr x
    combine (WrapF t) (_, a) (_, b) =
      let total = SR.occ_add sr a b
      in Just (mkProdNote sr total t, total)
-- | Pure version of 'prodEvalM'; 'Nothing' when no factor has a non-zero
--   occurrence count.
prodEval ::
  (r -> r -> r) ->
  (f (SR.SemiRingBase sr) -> r) ->
  SemiRingProduct f sr ->
  Maybe r
prodEval mul tm =
  runIdentity . prodEvalM (\a b -> pure (mul a b)) (pure . tm)
-- | Monadically evaluate a product.
--
--   Each factor @x@ with occurrence count @n@ is evaluated once with @tm@ and
--   multiplied into the running result @n@ times using @mul@ (as
--   @mul acc t@).  Factors with a zero count are skipped.  Returns 'Nothing'
--   when no factor has a non-zero count (e.g. the empty product); the caller
--   supplies the unit value in that case.
prodEvalM :: Monad m =>
  (r -> r -> m r) ->
  (f (SR.SemiRingBase sr) -> m r) ->
  SemiRingProduct f sr ->
  m (Maybe r)
prodEvalM mul tm om = f (AM.toList (_prodMap om))
  where
    sr = prodRepr om

    -- Find the first factor with a non-zero count; it seeds the accumulator.
    f [] = return Nothing
    f ((WrapF x, SR.occ_count sr -> n):xs)
      | n == 0 = f xs
      | otherwise =
          do t <- tm x
             -- Start from t itself and multiply n-1 more times: t^n.
             t' <- go (n-1) t t
             g xs t'

    -- Multiply each remaining factor, raised to its count, into z.
    g [] z = return (Just z)
    g ((WrapF x, SR.occ_count sr -> n):xs) z
      | n == 0 = g xs z
      | otherwise =
          do t <- tm x
             t' <- go n t z
             g xs t'

    -- Multiply t into the accumulator z, n times.
    go n t z
      | n > 0 = go (n-1) t =<< mul z t
      | otherwise = return z