-- |
-- Module      : Control.Monad.Bayes.Inference.SMC
-- Description : Sequential Monte Carlo (SMC)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- Sequential Monte Carlo (SMC) sampling.
--
-- Arnaud Doucet and Adam M. Johansen. 2011. A tutorial on particle filtering and smoothing: fifteen years later. In /The Oxford Handbook of Nonlinear Filtering/, Dan Crisan and Boris Rozovskii (Eds.). Oxford University Press, Chapter 8.
module Control.Monad.Bayes.Inference.SMC
  ( sir,
    smcMultinomial,
    smcSystematic,
    smcMultinomialPush,
    smcSystematicPush,
  )
where

import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Population
import Control.Monad.Bayes.Sequential as Seq

-- | Sequential importance resampling.
--
-- A generic SMC template parameterised by the resampling step.
-- Via 'Seq.hoistFirst', @'spawn' n@ is run ahead of the model's first step,
-- so the population starts with @n@ particles. 'sis' then advances the model
-- through @k@ timesteps, using @resampler@ as the per-step transformation.
sir ::
  Monad m =>
  -- | resampler
  (forall x. Population m x -> Population m x) ->
  -- | number of timesteps
  Int ->
  -- | population size
  Int ->
  -- | model
  Sequential (Population m) a ->
  Population m a
-- NOTE(review): the rank-2 resampler argument requires RankNTypes; no pragma
-- is visible in this file, so confirm it is enabled (e.g. via cabal
-- default-extensions).
sir resampler k n = sis resampler k . Seq.hoistFirst (spawn n >>)

-- | Sequential Monte Carlo with multinomial resampling at each timestep.
-- Weights are not normalized.
--
-- This is 'sir' specialised to the 'resampleMultinomial' resampler.
smcMultinomial ::
  MonadSample m =>
  -- | number of timesteps
  Int ->
  -- | number of particles
  Int ->
  -- | model
  Sequential (Population m) a ->
  Population m a
smcMultinomial = sir resampleMultinomial

-- | Sequential Monte Carlo with systematic resampling at each timestep.
-- Weights are not normalized.
--
-- This is 'sir' specialised to the 'resampleSystematic' resampler.
smcSystematic ::
  MonadSample m =>
  -- | number of timesteps
  Int ->
  -- | number of particles
  Int ->
  -- | model
  Sequential (Population m) a ->
  Population m a
smcSystematic = sir resampleSystematic

-- | Sequential Monte Carlo with multinomial resampling at each timestep.
-- Weights are normalized at each timestep and the total weight is pushed
-- as a score into the transformed monad.
--
-- The per-step transformation resamples with 'resampleMultinomial' and then
-- applies 'pushEvidence'. The former requires 'MonadSample' and the latter
-- 'MonadCond', hence the 'MonadInfer' constraint.
smcMultinomialPush ::
  MonadInfer m =>
  -- | number of timesteps
  Int ->
  -- | number of particles
  Int ->
  -- | model
  Sequential (Population m) a ->
  Population m a
smcMultinomialPush = sir (pushEvidence . resampleMultinomial)

-- | Sequential Monte Carlo with systematic resampling at each timestep.
-- Weights are normalized at each timestep and the total weight is pushed
-- as a score into the transformed monad.
--
-- The per-step transformation resamples with 'resampleSystematic' and then
-- applies 'pushEvidence'. The former requires 'MonadSample' and the latter
-- 'MonadCond', hence the 'MonadInfer' constraint.
smcSystematicPush ::
  MonadInfer m =>
  -- | number of timesteps
  Int ->
  -- | number of particles
  Int ->
  -- | model
  Sequential (Population m) a ->
  Population m a
smcSystematicPush = sir (pushEvidence . resampleSystematic)