module Control.Monad.Bayes.Weighted (
Weighted,
withWeight,
runWeighted,
extractWeight,
prior,
flatten,
applyWeight,
hoist,
) where
import Control.Monad.Trans
import Control.Monad.Trans.State
import Numeric.Log
import Control.Monad.Bayes.Class
-- | A computation in the underlying monad @m@ extended with one piece of
-- state: the accumulated weight, stored as a 'Log' 'Double'.
--
-- All listed instances are obtained by newtype deriving through the
-- wrapped 'StateT'.
newtype Weighted m a = Weighted (StateT (Log Double) m a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadTrans
    , MonadSample
    )
-- | Conditioning in 'Weighted' multiplies the accumulated weight by the
-- given factor; the underlying monad is not involved.
instance Monad m => MonadCond (Weighted m) where
  -- 'modify'' forces each new weight as it is stored. The plain 'modify'
  -- on the lazy 'StateT' exported by "Control.Monad.Trans.State" would
  -- instead leave a growing chain of unevaluated multiplications when
  -- 'score' is called many times.
  score w = Weighted (modify' (* w))
-- | 'Weighted' over a sampler supports full inference. No methods are
-- given here: the instance relies on the derived 'MonadSample' and the
-- 'MonadCond' instance for 'Weighted' (presumably 'MonadInfer' declares
-- no methods of its own — confirm in "Control.Monad.Bayes.Class").
instance MonadSample m => MonadInfer (Weighted m)
-- | Run a 'Weighted' computation, pairing its result with the weight
-- accumulated along the way. Accumulation starts from a weight of @1@.
runWeighted :: (Functor m) => Weighted m a -> m (a, Log Double)
runWeighted (Weighted st) = st `runStateT` initialWeight
  where
    initialWeight :: Log Double
    initialWeight = 1
-- | Run a 'Weighted' computation and return only its accumulated weight,
-- discarding the result value.
extractWeight :: Functor m => Weighted m a -> m (Log Double)
extractWeight = fmap snd . runWeighted
-- | Embed a computation that yields a value together with a weight into
-- 'Weighted'. The returned weight is multiplied into the accumulated
-- weight and the value becomes the result.
withWeight :: (Monad m) => m (a, Log Double) -> Weighted m a
withWeight m = Weighted $ do
  (x, w) <- lift m
  -- Strict update: the lazy 'StateT' from "Control.Monad.Trans.State"
  -- combined with plain 'modify' would defer every multiplication and
  -- build a thunk chain across repeated calls.
  modify' (* w)
  return x
-- | Run a 'Weighted' computation and throw away its weight. Only the
-- weight is affected by 'score', so this ignores all conditioning.
prior :: (Functor m) => Weighted m a -> m a
prior m = fst <$> runWeighted m
-- | Collapse two nested 'Weighted' layers into one. The resulting weight
-- is the product of the outer layer's weight and the inner layer's weight.
flatten :: Monad m => Weighted (Weighted m) a -> Weighted m a
flatten m = withWeight $ do
  -- The first 'runWeighted' strips the outer layer; the second strips the
  -- inner one.
  ((x, outerW), innerW) <- runWeighted (runWeighted m)
  return (x, outerW * innerW)
-- | Run a 'Weighted' computation inside a monad that itself supports
-- conditioning. The accumulated weight is passed to 'factor' before the
-- result is returned.
applyWeight :: MonadCond m => Weighted m a -> m a
applyWeight m =
  runWeighted m >>= \(x, w) ->
    factor w >> return x
-- | Change the underlying monad of a 'Weighted' computation using a
-- natural transformation. The transformation is applied to each
-- state-passing step; the weight itself is passed through untouched.
hoist :: (forall x. m x -> n x) -> Weighted m a -> Weighted n a
hoist t (Weighted st) = Weighted (StateT (\s -> t (runStateT st s)))