-- |
-- Module      : Control.Monad.Bayes.Inference.SMC2
-- Description : Sequential Monte Carlo squared (SMC²)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- Sequential Monte Carlo squared (SMC²) sampling.
--
-- Nicolas Chopin, Pierre E. Jacob, and Omiros Papaspiliopoulos. 2013. SMC²: an efficient algorithm for sequential analysis of state space models. /Journal of the Royal Statistical Society Series B: Statistical Methodology/ 75(3) (2013), 397-426. <https://doi.org/10.1111/j.1467-9868.2012.01046.x>
module Control.Monad.Bayes.Inference.SMC2
  ( smc2,
  )
where

import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Helpers
import Control.Monad.Bayes.Inference.RMSMC
import Control.Monad.Bayes.Inference.SMC
import Control.Monad.Bayes.Population as Pop
import Control.Monad.Trans
import Numeric.Log

-- | Helper monad transformer for preprocessing the model for 'smc2'.
--
-- It is a thin wrapper around the outer inference stack
-- @S (T (P m))@ (sequential / traced / population). The inner SMC of 'smc2'
-- runs a @Population (SMC2 m)@, so every effect the inner model performs is
-- routed through the instances below into that outer stack.
--
-- NOTE(review): deriving 'Applicative' and 'Monad' for a newtype needs
-- GeneralizedNewtypeDeriving. No LANGUAGE pragma is visible in this file, so
-- it presumably comes from the package's default-extensions; confirm.
newtype SMC2 m a = SMC2 (S (T (P m)) a)
  deriving (Functor, Applicative, Monad)

-- | Strip the 'SMC2' wrapper, exposing the underlying outer inference stack.
-- This is the inverse of the 'SMC2' constructor.
setup :: SMC2 m a -> S (T (P m)) a
setup (SMC2 m) = m

-- | Lift a base-monad action through all three layers of the wrapped stack:
-- @m -> P m -> T (P m) -> S (T (P m))@, then wrap it in 'SMC2'.
instance MonadTrans SMC2 where
  lift = SMC2 . lift . lift . lift

-- | Random draws are taken directly from the base monad @m@ and lifted through
-- the whole stack.
--
-- NOTE(review): drawing from @m@ rather than through the 'T' layer's own
-- 'random' presumably keeps inner-model randomness out of the outer trace, so
-- that only parameter draws are subject to the outer MH moves. Confirm against
-- the Traced implementation.
instance MonadSample m => MonadSample (SMC2 m) where
  random = lift random

-- | Scores are forwarded to the outer stack @S (T (P m))@, so they reweight the
-- outer particles rather than being absorbed by the base monad.
--
-- NOTE(review): 'smcSystematicPush' (used by 'smc2') presumably relies on this
-- to push the inner-population evidence estimate to the outer level; confirm.
instance Monad m => MonadCond (SMC2 m) where
  score = SMC2 . score

-- | Combines the 'MonadSample' and 'MonadCond' instances above. 'smc2' needs
-- this because it runs 'smcSystematicPush' (which requires 'MonadInfer') on a
-- @Population (SMC2 m)@.
instance MonadSample m => MonadInfer (SMC2 m)

-- | Sequential Monte Carlo squared.
--
-- An outer resample-move SMC ('rmsmc') runs over the model parameters. For each
-- outer parameter draw, an inner SMC ('smcSystematicPush') runs the model with
-- that parameter value. Both levels advance for the same number of time steps.
--
-- The result is the outer population. Each outer particle carries the weighted
-- inner population (as returned by 'runPopulation') obtained for its parameter
-- draw.
smc2 ::
  MonadSample m =>
  -- | number of time steps (shared by the inner and outer SMC)
  Int ->
  -- | number of inner particles
  Int ->
  -- | number of outer particles
  Int ->
  -- | number of MH transitions
  Int ->
  -- | model parameters
  S (T (P m)) b ->
  -- | model
  (b -> S (P (SMC2 m)) a) ->
  P m [(a, Log Double)]
smc2 k n p t param model =
  rmsmc k p t (param >>= innerSMC)
  where
    -- Run the inner SMC (k steps, n particles) for one parameter draw, then
    -- collapse its population to a weighted list and unwrap it into the outer
    -- stack. The explicit argument avoids the monomorphism restriction.
    innerSMC b = setup (runPopulation (smcSystematicPush k n (model b)))