{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RecordWildCards #-}

-- |
-- Module      : Control.Monad.Bayes.Inference.RMSMC
-- Description : Resample-Move Sequential Monte Carlo (RM-SMC)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- Resample-move Sequential Monte Carlo (RM-SMC) sampling.
--
-- Walter Gilks and Carlo Berzuini. 2001. Following a moving target - Monte Carlo inference for dynamic Bayesian models. /Journal of the Royal Statistical Society/ 63 (2001), 127-146. <http://www.mathcs.emory.edu/~whalen/Papers/BNs/MonteCarlo-DBNs.pdf>
module Control.Monad.Bayes.Inference.RMSMC
  ( rmsmc,
    rmsmcDynamic,
    rmsmcBasic,
  )
where

import Control.Monad.Bayes.Class (MonadDistribution)
import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..))
import Control.Monad.Bayes.Inference.SMC
import Control.Monad.Bayes.Population
  ( PopulationT,
    spawn,
    withParticles,
  )
import Control.Monad.Bayes.Sequential.Coroutine as Seq
import Control.Monad.Bayes.Sequential.Coroutine qualified as S
import Control.Monad.Bayes.Traced.Basic qualified as TrBas
import Control.Monad.Bayes.Traced.Dynamic qualified as TrDyn
import Control.Monad.Bayes.Traced.Static as Tr
  ( TracedT,
    marginal,
    mhStep,
  )
import Control.Monad.Bayes.Traced.Static qualified as TrStat
import Data.Monoid (Endo (..))

-- | Resample-move Sequential Monte Carlo.
--
-- The pipeline (read right to left) is:
--
-- 1. Before the model's first suspension point, the population is grown to
--    'numParticles' particles using 'spawn'.
-- 2. The model is run with 'S.sequentially' for 'numSteps' steps. At each
--    step the population is resampled with 'resampler', lifted into the
--    traced monad. Every particle is then rejuvenated with 'numMCMCSteps'
--    applications of 'mhStep'. This is the \"resample\" followed by the
--    \"move\".
-- 3. The execution trace is dropped with 'marginal'.
--
-- The 'numBurnIn' and 'proposal' fields of 'MCMCConfig' are not used here.
rmsmc ::
  (MonadDistribution m) =>
  -- | MCMC configuration; only 'numMCMCSteps' is read
  MCMCConfig ->
  -- | SMC configuration: 'resampler', 'numSteps' and 'numParticles'
  SMCConfig m ->
  -- | model
  SequentialT (TracedT (PopulationT m)) a ->
  PopulationT m a
rmsmc (MCMCConfig {..}) (SMCConfig {..}) =
  marginal
    -- '.' applies its right operand first, so each step resamples and
    -- then runs the MH moves on the resampled population.
    . S.sequentially
      (composeCopies numMCMCSteps mhStep . TrStat.hoist resampler)
      numSteps
    -- Only the computation before the first suspension point is
    -- transformed, so the population is grown exactly once.
    . S.hoistFirst (TrStat.hoist (spawn numParticles >>))

-- | Resample-move Sequential Monte Carlo with a more efficient
-- tracing representation.
--
-- This is the same algorithm as 'rmsmc', but it runs over the traced monad
-- from "Control.Monad.Bayes.Traced.Basic":
--
-- 1. Before the model's first suspension point, the population is set up
--    with 'withParticles' 'numParticles'.
-- 2. At each of the 'numSteps' steps, the population is resampled with
--    'resampler' and then moved with 'numMCMCSteps' applications of
--    'TrBas.mhStep'.
-- 3. The trace is dropped with 'TrBas.marginal'.
--
-- The 'numBurnIn' and 'proposal' fields of 'MCMCConfig' are not used here.
rmsmcBasic ::
  (MonadDistribution m) =>
  -- | MCMC configuration; only 'numMCMCSteps' is read
  MCMCConfig ->
  -- | SMC configuration: 'resampler', 'numSteps' and 'numParticles'
  SMCConfig m ->
  -- | model
  SequentialT (TrBas.TracedT (PopulationT m)) a ->
  PopulationT m a
rmsmcBasic (MCMCConfig {..}) (SMCConfig {..}) =
  TrBas.marginal
    -- Resample first (right operand of '.'), then apply the MH moves.
    . S.sequentially
      (composeCopies numMCMCSteps TrBas.mhStep . TrBas.hoist resampler)
      numSteps
    -- Set up the particle population once, before the first suspension.
    . S.hoistFirst (TrBas.hoist (withParticles numParticles))

-- | A variant of resample-move Sequential Monte Carlo
-- where only random variables since last resampling are considered
-- for rejuvenation.
--
-- This runs over the traced monad from "Control.Monad.Bayes.Traced.Dynamic".
-- At each of the 'numSteps' steps it does three things in order:
--
-- 1. Resamples with 'resampler'.
-- 2. Applies 'numMCMCSteps' rounds of 'TrDyn.mhStep'.
-- 3. Calls 'TrDyn.freeze'.
--
-- Freezing after every rejuvenation is what keeps earlier random choices
-- out of later MH moves.
--
-- The population is set up once with 'withParticles' 'numParticles', and
-- the trace is finally dropped with 'TrDyn.marginal'. The 'numBurnIn' and
-- 'proposal' fields of 'MCMCConfig' are not used here.
rmsmcDynamic ::
  (MonadDistribution m) =>
  -- | MCMC configuration; only 'numMCMCSteps' is read
  MCMCConfig ->
  -- | SMC configuration: 'resampler', 'numSteps' and 'numParticles'
  SMCConfig m ->
  -- | model
  SequentialT (TrDyn.TracedT (PopulationT m)) a ->
  PopulationT m a
rmsmcDynamic (MCMCConfig {..}) (SMCConfig {..}) =
  TrDyn.marginal
    -- Right to left: resample, move with MH, then freeze the trace.
    . S.sequentially
      ( TrDyn.freeze
          . composeCopies numMCMCSteps TrDyn.mhStep
          . TrDyn.hoist resampler
      )
      numSteps
    -- Set up the particle population once, before the first suspension.
    . S.hoistFirst (TrDyn.hoist (withParticles numParticles))

-- | Apply a function a given number of times.
--
-- @composeCopies k f@ composes @k@ copies of @f@. For @k <= 0@ the list
-- @replicate k (Endo f)@ is empty. Its 'mconcat' is then 'mempty', which is
-- @Endo id@, so the result is the identity function.
composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies k = withEndo (mconcat . replicate k)

-- | Run a transformation of 'Endo' values on a plain function.
--
-- @withEndo f g@ does three things:
--
-- 1. Wraps @g@ in 'Endo'.
-- 2. Applies @f@ to the wrapped value.
-- 3. Unwraps the result with 'appEndo'.
withEndo :: (Endo a -> Endo b) -> (a -> a) -> b -> b
withEndo f = appEndo . f . Endo