{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}

{-|
Module      : Control.Monad.Bayes.Weighted
Description : Probability monad accumulating the likelihood
Copyright   : (c) Adam Scibior, 2015-2020
License     : MIT
Maintainer  : leonhard.markert@tweag.io
Stability   : experimental
Portability : GHC

'Weighted' is an instance of 'MonadCond'. Apply a 'MonadSample' transformer to
obtain a 'MonadInfer' that can execute probabilistic models.
-}
module Control.Monad.Bayes.Weighted
  ( Weighted,
    withWeight,
    runWeighted,
    extractWeight,
    prior,
    flatten,
    applyWeight,
    hoist,
  )
where

import Control.Monad.Trans
import Control.Monad.Trans.State
import Numeric.Log

import Control.Monad.Bayes.Class

-- | Execute the program using the prior distribution, while accumulating
-- likelihood.
--
-- The accumulated likelihood is carried as the state of a 'StateT' layer over
-- the underlying monad @m@.
newtype Weighted m a = Weighted (StateT (Log Double) m a)
  -- StateT is more efficient than WriterT
  deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadSample)

-- | Scoring multiplies the accumulated weight by the given factor.
instance Monad m => MonadCond (Weighted m) where
  score w = Weighted (modify (* w))

instance MonadSample m => MonadInfer (Weighted m)

-- | Obtain an explicit value of the likelihood for a given value.
--
-- The computation is run with an initial weight of @1@; the result pairs the
-- sample with the final accumulated weight.
runWeighted :: (Functor m) => Weighted m a -> m (a, Log Double)
runWeighted (Weighted m) = runStateT m 1

-- | Compute the weight and discard the sample.
extractWeight :: Functor m => Weighted m a -> m (Log Double)
extractWeight m = snd <$> runWeighted m

-- | Embed a random variable with explicitly given likelihood.
--
-- > runWeighted . withWeight = id
withWeight :: (Monad m) => m (a, Log Double) -> Weighted m a
withWeight m = Weighted $ do
  (x, w) <- lift m
  -- Fold the supplied weight into whatever has been accumulated so far.
  modify (* w)
  return x

-- | Discard the weight.
-- This operation introduces bias.
prior :: (Functor m) => Weighted m a -> m a
prior = fmap fst . runWeighted

-- | Combine weights from two different levels.
--
-- Both 'Weighted' layers are run and their weights are multiplied into a
-- single weight on the remaining layer.
flatten :: Monad m => Weighted (Weighted m) a -> Weighted m a
flatten m = withWeight $ mergeWeights <$> runWeighted (runWeighted m)
  where
    -- Collapse the nested (sample, weight) pairs into one pair carrying the
    -- product of the two weights.
    mergeWeights ((x, p), q) = (x, p * q)

-- | Use the weight as a factor in the transformed monad.
applyWeight :: MonadCond m => Weighted m a -> m a
applyWeight weighted = runWeighted weighted >>= reweight
  where
    -- Hand the accumulated weight to 'factor' in the underlying monad, then
    -- yield the sample unchanged.
    reweight (sample, weight) = factor weight >> return sample

-- | Apply a transformation to the transformed monad.
--
-- The supplied polymorphic function is applied to the underlying state
-- computation via 'mapStateT'.
hoist :: (forall x. m x -> n x) -> Weighted m a -> Weighted n a
hoist transform (Weighted inner) = Weighted (mapStateT transform inner)