{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RecordWildCards #-}
module Control.Monad.Bayes.Inference.RMSMC
( rmsmc,
rmsmcDynamic,
rmsmcBasic,
)
where
import Control.Monad.Bayes.Class (MonadDistribution)
import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..))
import Control.Monad.Bayes.Inference.SMC
import Control.Monad.Bayes.Population
( Population,
spawn,
withParticles,
)
import Control.Monad.Bayes.Sequential.Coroutine as Seq
import Control.Monad.Bayes.Sequential.Coroutine qualified as S
import Control.Monad.Bayes.Traced.Basic qualified as TrBas
import Control.Monad.Bayes.Traced.Dynamic qualified as TrDyn
import Control.Monad.Bayes.Traced.Static as Tr
( Traced,
marginal,
mhStep,
)
import Control.Monad.Bayes.Traced.Static qualified as TrStat
import Data.Monoid (Endo (..))
-- | Resample-move Sequential Monte Carlo using the static trace
-- representation ("Control.Monad.Bayes.Traced.Static").
--
-- The pipeline, read right to left:
--
-- 1. 'S.hoistFirst' runs @spawn numParticles >>@ (lifted through
--    'TrStat.hoist') on the model's first stage, creating the particle
--    population.
-- 2. 'S.sequentially' is given @numSteps@ and a per-step transformation.
--    That transformation first applies @resampler@ to the underlying
--    'Population' and then applies 'mhStep' @numMCMCSteps@ times.
--    Presumably it is applied at each of the first @numSteps@ suspension
--    points; confirm against 'S.sequentially'.
-- 3. 'marginal' discards the trace, leaving a plain 'Population'.
--
-- Only @numMCMCSteps@ is read from the 'MCMCConfig'. Its @numBurnIn@ and
-- @proposal@ fields are not used here.
rmsmc ::
  MonadDistribution m =>
  MCMCConfig ->
  SMCConfig m ->
  -- | model
  Sequential (Traced (Population m)) a ->
  Population m a
rmsmc (MCMCConfig {..}) (SMCConfig {..}) =
  marginal
    . S.sequentially
      -- The rightmost stage runs first: resample, then the MH moves.
      (composeCopies numMCMCSteps mhStep . TrStat.hoist resampler)
      numSteps
    . S.hoistFirst (TrStat.hoist (spawn numParticles >>))
-- | Resample-move Sequential Monte Carlo using the basic trace
-- representation ("Control.Monad.Bayes.Traced.Basic").
--
-- It has the same structure as 'rmsmc', with three differences:
--
-- * every trace operation comes from the @TrBas@ module;
-- * the population is created with @withParticles numParticles@;
-- * the model type uses 'TrBas.Traced' rather than the static 'Traced'.
--
-- In the per-step transformation passed to 'S.sequentially', @resampler@
-- runs first and 'TrBas.mhStep' is then applied @numMCMCSteps@ times.
--
-- Only @numMCMCSteps@ is read from the 'MCMCConfig'. Its @numBurnIn@ and
-- @proposal@ fields are not used here.
rmsmcBasic ::
  MonadDistribution m =>
  MCMCConfig ->
  SMCConfig m ->
  -- | model
  Sequential (TrBas.Traced (Population m)) a ->
  Population m a
rmsmcBasic (MCMCConfig {..}) (SMCConfig {..}) =
  TrBas.marginal
    . S.sequentially
      -- The rightmost stage runs first: resample, then the MH moves.
      (composeCopies numMCMCSteps TrBas.mhStep . TrBas.hoist resampler)
      numSteps
    . S.hoistFirst (TrBas.hoist (withParticles numParticles))
-- | Resample-move Sequential Monte Carlo using the dynamic trace
-- representation ("Control.Monad.Bayes.Traced.Dynamic").
--
-- This is like 'rmsmcBasic', but the per-step transformation ends with
-- 'TrDyn.freeze'. Each step therefore runs, in order:
--
-- 1. @resampler@ on the underlying 'Population';
-- 2. 'TrDyn.mhStep', applied @numMCMCSteps@ times;
-- 3. 'TrDyn.freeze'.
--
-- NOTE(review): 'TrDyn.freeze' presumably fixes the trace so far, so that
-- later MH steps only revisit choices made after it. Confirm against
-- "Control.Monad.Bayes.Traced.Dynamic".
--
-- Only @numMCMCSteps@ is read from the 'MCMCConfig'. Its @numBurnIn@ and
-- @proposal@ fields are not used here.
rmsmcDynamic ::
  MonadDistribution m =>
  MCMCConfig ->
  SMCConfig m ->
  -- | model
  Sequential (TrDyn.Traced (Population m)) a ->
  Population m a
rmsmcDynamic (MCMCConfig {..}) (SMCConfig {..}) =
  TrDyn.marginal
    . S.sequentially
      -- The rightmost stage runs first: resample, MH moves, then freeze.
      ( TrDyn.freeze
          . composeCopies numMCMCSteps TrDyn.mhStep
          . TrDyn.hoist resampler
      )
      numSteps
    . S.hoistFirst (TrDyn.hoist (withParticles numParticles))
-- | @composeCopies k f@ is @f@ composed with itself @k@ times, so
-- @composeCopies 3 f = f . f . f@.
--
-- The copies are combined with the 'Endo' monoid. For @k <= 0@,
-- 'replicate' yields no copies and the result is 'id'.
composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies k = withEndo (mconcat . replicate k)
-- | Turn a transformation on 'Endo' values into one on plain functions.
--
-- The argument function is wrapped with 'Endo', @f@ is applied to it, and
-- the result is unwrapped with 'appEndo'.
withEndo :: (Endo a -> Endo b) -> (a -> a) -> b -> b
withEndo f = appEndo . f . Endo