{-# LANGUAGE RankNTypes #-}

-- |
-- Module      : Control.Monad.Bayes.Inference.PMMH
-- Description : Particle Marginal Metropolis-Hastings (PMMH)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- Particle Marginal Metropolis-Hastings (PMMH) sampling.
--
-- Christophe Andrieu, Arnaud Doucet, and Roman Holenstein. 2010. Particle Markov chain Monte Carlo Methods. /Journal of the Royal Statistical Society/ 72 (2010), 269-342. <http://www.stats.ox.ac.uk/~doucet/andrieu_doucet_holenstein_PMCMC.pdf>
module Control.Monad.Bayes.Inference.PMMH
  ( pmmh,
    pmmhBayesianModel,
  )
where

import Control.Monad.Bayes.Class (Bayesian (generative), MonadDistribution, MonadMeasure, prior)
import Control.Monad.Bayes.Inference.MCMC (MCMCConfig, mcmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (), smc)
import Control.Monad.Bayes.Population as Pop
  ( Population,
    hoist,
    population,
    pushEvidence,
  )
import Control.Monad.Bayes.Sequential.Coroutine (Sequential)
import Control.Monad.Bayes.Traced.Static (Traced)
import Control.Monad.Bayes.Weighted
import Control.Monad.Trans (lift)
import Numeric.Log (Log)

-- | Particle Marginal Metropolis-Hastings sampling.
--
-- An outer MCMC chain (driven by 'mcmc') runs over the parameter prior
-- @param@. For every parameter value drawn there, an inner SMC run ('smc')
-- is performed on @model@. Its resulting particle population is folded back
-- into the traced computation, so each MCMC sample is a whole weighted
-- population.
--
-- Arguments:
--
-- * @mcmcConf@ – configuration of the outer MCMC chain.
-- * @smcConf@  – configuration of the inner SMC run, executed once per
--   parameter value.
-- * @param@    – traced computation producing the parameter value.
-- * @model@    – sequential model conditioned on a parameter value.
--
-- Returns one entry per MCMC sample. Each entry is the list of
-- @(particle, weight)@ pairs produced by the inner SMC run.
pmmh ::
  MonadDistribution m =>
  MCMCConfig ->
  SMCConfig (Weighted m) ->
  Traced (Weighted m) a1 ->
  (a1 -> Sequential (Population (Weighted m)) a2) ->
  m [[(a2, Log Double)]]
pmmh mcmcConf smcConf param model =
  mcmc
    mcmcConf
    ( param
        -- Stages run bottom-up for each parameter value:
        >>= population -- expose the particles as [(a2, Log Double)]
          -- NOTE(review): presumably moves the population's total weight
          -- (the SMC evidence estimate) into the enclosing monad, so the
          -- outer MCMC acceptance sees it, as in the PMMH construction.
          . pushEvidence
          . Pop.hoist lift -- lift Weighted m effects into Traced (Weighted m)
          . smc smcConf -- run SMC on the model
          . model
    )

-- | Particle Marginal Metropolis-Hastings sampling from a 'Bayesian' model.
--
-- A convenience wrapper around 'pmmh':
--
-- * the model's 'prior' is used as the parameter computation of the outer
--   MCMC chain;
-- * its 'generative' part is used as the sequential model for the inner
--   SMC run.
--
-- The model argument is rank-2 polymorphic (requires @RankNTypes@) because
-- it is instantiated at two different monads:
--
-- * @Traced (Weighted m)@ for 'prior';
-- * @Sequential (Population (Weighted m))@ for 'generative'.
pmmhBayesianModel ::
  MonadMeasure m =>
  MCMCConfig ->
  SMCConfig (Weighted m) ->
  (forall m'. MonadMeasure m' => Bayesian m' a1 a2) ->
  m [[(a2, Log Double)]]
pmmhBayesianModel mcmcConf smcConf bm =
  pmmh mcmcConf smcConf (prior bm) (generative bm)