{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- |
-- Module      : Control.Monad.Bayes.Inference.SMC2
-- Description : Sequential Monte Carlo squared (SMC²)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- Sequential Monte Carlo squared (SMC²) sampling.
--
-- Nicolas Chopin, Pierre E. Jacob, and Omiros Papaspiliopoulos. 2013. SMC²: an efficient algorithm for sequential analysis of state space models. /Journal of the Royal Statistical Society Series B: Statistical Methodology/ 75 (2013), 397-426. Issue 3. <https://doi.org/10.1111/j.1467-9868.2012.01046.x>
module Control.Monad.Bayes.Inference.SMC2
  ( smc2,
    SMC2,
  )
where

import Control.Monad.Bayes.Class
  ( MonadDistribution (random),
    MonadFactor (..),
    MonadMeasure,
  )
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush)
import Control.Monad.Bayes.Population as Pop (Population, population, resampleMultinomial)
import Control.Monad.Bayes.Sequential.Coroutine (Sequential)
import Control.Monad.Bayes.Traced
import Control.Monad.Trans (MonadTrans (..))
import Numeric.Log (Log)

-- | Helper monad transformer for preprocessing the model for 'smc2'.
--
-- It is a thin wrapper around the @'Sequential' ('Traced' ('Population' m))@
-- stack. The 'Functor', 'Applicative' and 'Monad' instances are reused from
-- that wrapped stack via the @newtype@ deriving strategy.
newtype SMC2 m a = SMC2 (Sequential (Traced (Population m)) a)
  deriving newtype (Functor, Applicative, Monad)

-- | Unwrap an 'SMC2' computation, exposing the underlying
-- @'Sequential' ('Traced' ('Population' m))@ computation it wraps.
setup :: SMC2 m a -> Sequential (Traced (Population m)) a
setup (SMC2 m) = m

-- | Lifting goes through the three wrapped transformer layers
-- ('Population', then 'Traced', then 'Sequential') before re-wrapping the
-- result in the 'SMC2' constructor.
instance MonadTrans SMC2 where
  lift = SMC2 . lift . lift . lift

-- | Randomness is drawn from the base monad @m@ and lifted into 'SMC2'.
instance MonadDistribution m => MonadDistribution (SMC2 m) where
  random = lift random

-- | Scoring is delegated to the 'score' of the wrapped
-- @'Sequential' ('Traced' ('Population' m))@ stack rather than lifted from
-- the base monad, which is why only @'Monad' m@ is required here.
instance Monad m => MonadFactor (SMC2 m) where
  score = SMC2 . score

-- | 'SMC2' supports both sampling and scoring; this instance declares no
-- methods of its own.
instance MonadDistribution m => MonadMeasure (SMC2 m)

-- | Sequential Monte Carlo squared.
--
-- The outer layer runs 'rmsmc' over the parameter model with single-site
-- Metropolis-Hastings moves and no burn-in. For each parameter value, the
-- model is run by an inner 'smcPush', and the weighted inner population
-- (obtained with 'population') becomes the value carried by the outer
-- particle. Both layers resample with 'resampleMultinomial'.
smc2 ::
  MonadDistribution m =>
  -- | number of time steps
  Int ->
  -- | number of inner particles
  Int ->
  -- | number of outer particles
  Int ->
  -- | number of MH transitions
  Int ->
  -- | model parameters
  Sequential (Traced (Population m)) b ->
  -- | model
  (b -> Sequential (Population (SMC2 m)) a) ->
  Population m [(a, Log Double)]
smc2 k n p t param m =
  rmsmc
    -- Outer rejuvenation moves: @t@ single-site MH steps per resampling point.
    MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0}
    -- Outer particle system over the parameters: @p@ particles, @k@ steps.
    SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial}
    -- Inner particle system over the model: @n@ particles, @k@ steps, run in
    -- the 'SMC2' monad and then unwrapped with 'setup'.
    ( param
        >>= setup
          . population
          . smcPush
            ( SMCConfig
                { numSteps = k,
                  numParticles = n,
                  resampler = resampleMultinomial
                }
            )
          . m
    )