-- |
-- Module      : Control.Monad.Bayes.Enumerator
-- Description : Exhaustive enumeration of discrete random variables
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
module Control.Monad.Bayes.Enumerator
  ( Enumerator,
    logExplicit,
    explicit,
    evidence,
    mass,
    compact,
    enumerate,
    expectation,
    normalForm,
  )
where

import Control.Applicative (Alternative)
import Control.Arrow (second)
import Control.Monad (MonadPlus)
import Control.Monad.Bayes.Class
import Control.Monad.Trans.Writer
import Data.AEq ((===), AEq, (~==))
import qualified Data.Map as Map
import Data.Maybe
import Data.Monoid
import qualified Data.Vector.Generic as V
import Numeric.Log as Log

-- | Exact inference by exhaustive enumeration: every discrete random choice
-- is expanded into all of its outcomes, and each resulting execution path
-- carries the product of the weights accumulated along it.
--
-- Internally this is a list of paths (the nondeterminism of '[]'), each
-- paired with a multiplicative log-domain weight ('Product' of 'Log' 'Double').
newtype Enumerator a
  = Enumerator (WriterT (Product (Log Double)) [] a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      Alternative,
      MonadPlus
    )

-- | Discrete sampling primitives, each expressed as an explicit weighted
-- support. 'random' has an uncountable support that cannot be enumerated,
-- so it raises an error instead.
--
-- 'bernoulli' yields 'True' with weight @p@ and 'False' with weight @1 - p@.
-- NOTE(review): @p@ is not validated. A value outside [0, 1] takes the log
-- of a negative number and so produces a NaN weight.
--
-- 'categorical' yields index @i@ with the @i@-th entry of the vector as its
-- weight. The vector is used as-is and is not normalized here.
instance MonadSample Enumerator where
  random =
    error "Infinitely supported random variables not supported in Enumerator"

  bernoulli p = fromList [(True, toLogWeight p), (False, toLogWeight (1 - p))]
    where
      toLogWeight = Exp . log

  categorical v = fromList (zip [0 ..] weights)
    where
      weights = map (Exp . log) (V.toList v)

-- | Conditioning creates a single path that carries weight @w@. Because path
-- weights combine through 'Product', this multiplies the weight of the
-- surrounding execution path by @w@.
instance MonadCond Enumerator where
  score w = fromList (pure ((), w))

-- | 'Enumerator' supports both discrete sampling and conditioning, so it is
-- usable as a complete inference monad for models with only discrete choices.
instance MonadInfer Enumerator

-- | Build an 'Enumerator' from explicit (value, weight) pairs. Each pair
-- becomes one execution path whose accumulated weight is the given weight.
fromList :: [(a, Log Double)] -> Enumerator a
fromList weighted =
  Enumerator (WriterT [(x, Product w) | (x, w) <- weighted])

-- | The raw posterior as a list of (value, weight) pairs, one per execution
-- path. No post-processing is done: weights stay in the log domain, are not
-- normalized, and equal values are not merged.
logExplicit :: Enumerator a -> [(a, Log Double)]
logExplicit (Enumerator m) = [(x, getProduct w) | (x, w) <- runWriterT m]

-- | Same as 'logExplicit', except that each weight is converted out of the log
-- domain into a plain 'Double'. There is no normalization or aggregation.
explicit :: Enumerator a -> [(a, Double)]
explicit d = [(x, exp (ln w)) | (x, w) <- logExplicit d]

-- | The model evidence: the sum of the weights of all execution paths,
-- computed in the log domain.
evidence :: Enumerator a -> Log Double
evidence d = Log.sum [w | (_, w) <- logExplicit d]

-- | Normalized probability mass of a specific value, as reported by
-- 'enumerate'. Values outside the support have mass 0.
--
-- The normalized table is built once per @mass d@. Partially applying
-- @mass d@ and querying many values therefore costs O(log n) per query
-- instead of a linear scan of the list on every call.
mass :: Ord a => Enumerator a -> a -> Double
mass d = massOf
  where
    -- 'enumerate' yields each value once, so no entries collide here.
    table = Map.fromList (enumerate d)
    massOf a = Map.findWithDefault 0 a table

-- | Merge entries that share a value by adding their weights. The result is
-- sorted by value in ascending order, and each value appears exactly once.
compact :: (Num r, Ord a) => [(a, r)] -> [(a, r)]
compact pairs = Map.toList merged
  where
    merged = Map.fromListWith (+) pairs

-- | Normalize the weights so that they sum to one, then aggregate equal
-- values. The resulting list is sorted by value in ascending order.
--
-- This is 'compact' applied to the output of 'explicit', after every weight
-- has first been divided by the 'evidence'.
enumerate :: Ord a => Enumerator a -> [(a, Double)]
enumerate d = compact (zip values probabilities)
  where
    (values, logWeights) = unzip (logExplicit d)
    probabilities = map (exp . ln) (normalize logWeights)

-- | Expected value of @f@ under the normalized posterior: the sum, over all
-- execution paths, of @f x@ times that path's normalized probability.
expectation :: (a -> Double) -> Enumerator a -> Double
expectation f d =
  Prelude.sum [f x * exp (ln w) | (x, w) <- normalizeWeights (logExplicit d)]

-- | Divide every log-domain weight by the total of all the weights, so that
-- the results sum to one whenever that total is non-zero.
normalize :: [Log Double] -> [Log Double]
normalize ws = [w / total | w <- ws]
  where
    total = Log.sum ws

-- | Divide all weights by their sum, keeping each weight paired with its
-- value and preserving the order of the input list.
normalizeWeights :: [(a, Log Double)] -> [(a, Log Double)]
normalizeWeights ls = zip (map fst ls) (normalize (map snd ls))

-- | 'compact' applied to 'explicit', followed by dropping every entry whose
-- weight is exactly zero.
--
-- The weights are NOT normalized. Two enumerators that describe the same
-- distribution with different total evidence therefore have different
-- normal forms.
normalForm :: Ord a => Enumerator a -> [(a, Double)]
normalForm d = [(x, w) | (x, w) <- compact (explicit d), w /= 0]

-- | Two enumerators are equal exactly when their 'normalForm's coincide:
-- same values, same unnormalized weights.
instance Ord a => Eq (Enumerator a) where
  (==) p q = normalForm p == normalForm q

-- | Comparisons of the (unnormalized) 'normalForm's.
--
-- '===' requires identical supports and weights that are identical under
-- 'AEq'. '~==' first discards entries whose weight is approximately zero,
-- then requires identical supports and approximately equal weights.
instance Ord a => AEq (Enumerator a) where
  p === q =
    let (xs, ps) = unzip (normalForm p)
        (ys, qs) = unzip (normalForm q)
     in xs == ys && ps === qs
  p ~== q =
    let keep kvs = filter (\(_, w) -> not (w ~== 0)) kvs
        (xs, ps) = unzip (keep (normalForm p))
        (ys, qs) = unzip (keep (normalForm q))
     in xs == ys && ps ~== qs