{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- |
-- Module      : Control.Monad.Bayes.Inference.SMC2
-- Description : Sequential Monte Carlo squared (SMC²)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- Sequential Monte Carlo squared (SMC²) sampling.
--
-- Nicolas Chopin, Pierre E. Jacob, and Omiros Papaspiliopoulos. 2013. SMC²: an efficient algorithm for sequential analysis of state space models. /Journal of the Royal Statistical Society Series B: Statistical Methodology/ 75 (2013), 397-426. Issue 3. <https://doi.org/10.1111/j.1467-9868.2012.01046.x>
module Control.Monad.Bayes.Inference.SMC2
  ( smc2,
    SMC2,
  )
where

import Control.Monad.Bayes.Class
  ( MonadDistribution (random),
    MonadFactor (..),
    MonadMeasure,
  )
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush)
import Control.Monad.Bayes.Population as Pop (PopulationT, resampleMultinomial, runPopulationT)
import Control.Monad.Bayes.Sequential.Coroutine (SequentialT)
import Control.Monad.Bayes.Traced
import Control.Monad.Trans (MonadTrans (..))
import Numeric.Log (Log)

-- | Helper monad transformer for preprocessing the model for 'smc2'.
--
-- It is a thin wrapper around the transformer stack
-- @'SequentialT' ('TracedT' ('PopulationT' m))@. The 'Functor',
-- 'Applicative' and 'Monad' instances are reused from that stack via
-- @GeneralizedNewtypeDeriving@.
newtype SMC2 m a = SMC2 (SequentialT (TracedT (PopulationT m)) a)
  deriving newtype (Functor, Applicative, Monad)

-- | Unwrap an 'SMC2' computation, exposing the underlying
-- @'SequentialT' ('TracedT' ('PopulationT' m))@ stack.
setup :: SMC2 m a -> SequentialT (TracedT (PopulationT m)) a
setup (SMC2 m) = m

-- | Lifting goes through all three wrapped layers ('PopulationT', then
-- 'TracedT', then 'SequentialT') before re-wrapping in 'SMC2'.
instance MonadTrans SMC2 where
  lift = SMC2 . lift . lift . lift

-- | Random draws are delegated to the base monad @m@ via 'lift'.
instance (MonadDistribution m) => MonadDistribution (SMC2 m) where
  random = lift random

-- | Scoring is not lifted to the base monad: it uses the 'MonadFactor'
-- instance of the wrapped @'SequentialT' ('TracedT' ('PopulationT' m))@
-- stack, so @m@ itself only needs to be a 'Monad'.
instance (Monad m) => MonadFactor (SMC2 m) where
  score = SMC2 . score

-- | No methods are defined here; the instance only declares that @'SMC2' m@
-- supports both sampling and scoring (presumably 'MonadMeasure' just combines
-- 'MonadDistribution' and 'MonadFactor' — confirm against
-- "Control.Monad.Bayes.Class"). 'smc2' needs it to run 'smcPush' in @'SMC2' m@.
instance (MonadDistribution m) => MonadMeasure (SMC2 m)

-- | Sequential Monte Carlo squared.
--
-- The outer layer runs 'rmsmc' over the model parameters with @p@ particles,
-- @k@ steps, multinomial resampling and @t@ 'SingleSiteMH' transitions
-- (no burn-in). For each parameter value drawn from @param@, the inner layer
-- runs the model through 'smcPush' with @n@ particles and @k@ steps, again
-- with multinomial resampling; the resulting weighted population is returned
-- as the value of the outer computation.
smc2 ::
  (MonadDistribution m) =>
  -- | number of time steps
  Int ->
  -- | number of inner particles
  Int ->
  -- | number of outer particles
  Int ->
  -- | number of MH transitions
  Int ->
  -- | model parameters
  SequentialT (TracedT (PopulationT m)) b ->
  -- | model
  (b -> SequentialT (PopulationT (SMC2 m)) a) ->
  PopulationT m [(a, Log Double)]
smc2 k n p t param m =
  rmsmc
    -- Outer layer: MCMC rejuvenation settings.
    MCMCConfig
      { numMCMCSteps = t,
        proposal = SingleSiteMH,
        numBurnIn = 0
      }
    -- Outer layer: particle filter over the parameters.
    SMCConfig
      { numParticles = p,
        numSteps = k,
        resampler = resampleMultinomial
      }
    -- Inner layer: for each parameter draw, run the model's particle filter
    -- inside 'SMC2', collect its weighted population with 'runPopulationT',
    -- and unwrap back into the outer transformer stack with 'setup'.
    ( param
        >>= setup
          . runPopulationT
          . smcPush
            ( SMCConfig
                { numSteps = k,
                  numParticles = n,
                  resampler = resampleMultinomial
                }
            )
          . m
    )