module Control.Monad.Bayes.Enumerator (
Enumerator,
logExplicit,
explicit,
evidence,
mass,
compact,
enumerate,
expectation,
normalForm
) where
import Data.AEq (AEq, (===), (~==))
import Control.Applicative (Applicative, Alternative)
import Control.Monad (MonadPlus)
import Control.Arrow (second)
import qualified Data.Map as Map
import qualified Data.Vector.Generic as V
import Numeric.Log as Log
import Control.Monad.Trans.Writer
import Data.Monoid
import Data.Maybe
import Control.Monad.Bayes.Class
-- | A weighted list of outcomes. Each outcome carries a weight in log
-- domain ('Log' 'Double'); weights of sequenced computations are combined
-- through the 'Product' monoid of the underlying writer, i.e. multiplied.
--
-- Weights are not required to sum to one, so a value may represent an
-- unnormalized measure (see 'evidence').
newtype Enumerator a
  = Enumerator (WriterT (Product (Log Double)) [] a)
  deriving (Functor, Applicative, Alternative, Monad, MonadPlus)
-- | Exact enumeration of the discrete sampling primitives.
--
-- 'random' has infinite support, so it cannot be enumerated and calls
-- 'error' instead.
instance MonadSample Enumerator where
  random = error "Infinitely supported random variables not supported in Enumerator"

  -- Two outcomes weighted by p and 1 - p, converted into log domain.
  bernoulli p = fromList [(True, Exp (log p)), (False, Exp (log (1 - p)))]

  -- One outcome per vector index, weighted by the entry at that index.
  categorical v = fromList (zipWith weighted [0 ..] (V.toList v))
    where
      weighted i x = (i, Exp (log x))
-- | Conditioning introduces a single unit outcome whose weight is the given
-- factor, so it scales the weight of every outcome it is sequenced with.
instance MonadCond Enumerator where
  score factor = fromList (pure ((), factor))
-- | 'Enumerator' provides both sampling and conditioning; no methods are
-- defined here beyond those two instances.
instance MonadInfer Enumerator where
-- | Build an 'Enumerator' directly from weighted outcomes.
--
-- Weights are stored as given: they are not normalized, and repeated
-- outcomes stay as separate entries.
fromList :: [(a, Log Double)] -> Enumerator a
fromList outcomes = Enumerator (WriterT (map attach outcomes))
  where
    -- Lazy pattern: wrap the weight without forcing the pair itself.
    attach ~(x, w) = (x, Product w)
-- | All outcomes paired with their raw (unnormalized) log-domain weights,
-- in enumeration order and without merging duplicates.
logExplicit :: Enumerator a -> [(a, Log Double)]
logExplicit (Enumerator m) = map detach (runWriterT m)
  where
    -- Lazy pattern: strip the 'Product' wrapper without forcing the pair.
    detach ~(x, Product w) = (x, w)
-- | Like 'logExplicit', but with each weight converted from log domain to a
-- plain 'Double'. The weights are not normalized.
explicit :: Enumerator a -> [(a, Double)]
explicit d = [(x, exp (ln w)) | (x, w) <- logExplicit d]
-- | Total weight of all outcomes, i.e. the normalizing constant, kept in log
-- domain.
evidence :: Enumerator a -> Log Double
evidence d = Log.sum [w | (_, w) <- logExplicit d]
-- | Normalized probability of a single outcome.
--
-- @mass d@ enumerates @d@ once and returns a query function. Outcomes
-- outside the support get probability 0. Queries are answered from a
-- 'Map.Map' in O(log n) instead of a linear scan of the association list,
-- so partially applying @mass d@ and querying many outcomes stays cheap.
mass :: Ord a => Enumerator a -> a -> Double
mass d = f
  where
    -- Built lazily on the first query and shared by every later call of f.
    -- 'enumerate' merges duplicate outcomes, so no entries are lost here.
    table = Map.fromList (enumerate d)
    f a = Map.findWithDefault 0 a table
-- | Merge entries with equal keys by summing their values, returning the
-- result sorted by key in ascending order.
compact :: (Num r, Ord a) => [(a, r)] -> [(a, r)]
compact entries = Map.toList merged
  where
    -- fromListWith inserts left to right, which fixes the summation order.
    merged = Map.fromListWith (+) entries
-- | The normalized distribution as an ascending association list with
-- duplicate outcomes merged.
--
-- NOTE(review): if every weight is zero, 'normalize' divides by a zero
-- total; the resulting probabilities are presumably NaN. Confirm whether
-- callers rely on that.
enumerate :: Ord a => Enumerator a -> [(a, Double)]
enumerate d = compact (zip values probs)
  where
    (values, logWeights) = unzip (logExplicit d)
    probs = map (exp . ln) (normalize logWeights)
-- | Expected value of @f@ under the normalized distribution.
expectation :: (a -> Double) -> Enumerator a -> Double
expectation f d =
  Prelude.sum [f x * exp (ln w) | (x, w) <- normalizeWeights (logExplicit d)]
-- | Divide every weight by their total, so the results sum to one.
--
-- NOTE(review): an all-zero input makes the total zero, and the division
-- then presumably yields NaN. An empty input yields an empty list.
normalize :: [Log Double] -> [Log Double]
normalize xs = [x / total | x <- xs]
  where
    total = Log.sum xs
-- | Normalize the weights of a weighted list with 'normalize', leaving the
-- outcomes and their order unchanged.
normalizeWeights :: [(a, Log Double)] -> [(a, Log Double)]
normalizeWeights ls = zip outcomes (normalize weights)
  where
    (outcomes, weights) = unzip ls
-- | Canonical representation used for comparisons: outcomes merged and
-- sorted, with zero-weight entries removed. Weights are NOT normalized, so
-- measures that differ by a constant factor have different normal forms.
normalForm :: Ord a => Enumerator a -> [(a, Double)]
normalForm d = [entry | entry@(_, w) <- compact (explicit d), w /= 0]
-- | Exact equality of the canonical forms produced by 'normalForm'.
instance Ord a => Eq (Enumerator a) where
  p == q = lhs == rhs
    where
      lhs = normalForm p
      rhs = normalForm q
-- | Comparisons on the canonical forms produced by 'normalForm'.
--
-- Both operators require the outcome lists to match exactly. '===' then
-- applies AEq's '===' to the weight lists. '~==' first discards entries
-- whose weight is approximately zero (per '~==') and then compares the
-- remaining weights with '~=='.
instance Ord a => AEq (Enumerator a) where
  p === q = xs == ys && ps === qs
    where
      (xs, ps) = unzip (normalForm p)
      (ys, qs) = unzip (normalForm q)
  p ~== q = xs == ys && ps ~== qs
    where
      significant = unzip . filter (not . (~== 0) . snd) . normalForm
      (xs, ps) = significant p
      (ys, qs) = significant q