{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Control.Monad.Bayes.Inference.SMC2
( smc2,
SMC2,
)
where
import Control.Monad.Bayes.Class
( MonadDistribution (random),
MonadFactor (..),
MonadMeasure,
)
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush)
import Control.Monad.Bayes.Population as Pop (Population, population, resampleMultinomial)
import Control.Monad.Bayes.Sequential.Coroutine (Sequential)
import Control.Monad.Bayes.Traced
import Control.Monad.Trans (MonadTrans (..))
import Numeric.Log (Log)
-- | Helper monad transformer used to preprocess the model for 'smc2'.
--
-- It is a thin wrapper over the outer sampler stack
-- @Sequential (Traced (Population m))@. The instances for 'Functor',
-- 'Applicative' and 'Monad' are reused unchanged from that stack via
-- newtype deriving. The effect instances further down in this module decide
-- which layer each effect is routed to.
newtype SMC2 m a = SMC2 (Sequential (Traced (Population m)) a)
  deriving newtype (Functor, Applicative, Monad)
-- | Unwrap an 'SMC2' computation back into the outer sampler stack it wraps,
-- so that it can be handed to the outer inference algorithm.
setup :: SMC2 m a -> Sequential (Traced (Population m)) a
setup (SMC2 m) = m
-- | Lift a base-monad action through every layer of the wrapped stack:
-- first into 'Population', then 'Traced', then 'Sequential', and finally
-- wrap it in 'SMC2'.
instance MonadTrans SMC2 where
  lift = SMC2 . lift . lift . lift
-- | Randomness is taken directly from the base monad @m@ via 'lift'. It
-- bypasses the 'Sequential', 'Traced' and 'Population' layers of the
-- wrapped stack.
--
-- NOTE(review): presumably this is deliberate, so that the inner filter's
-- random draws are not recorded by 'Traced' and are redrawn afresh on each
-- MH move. Confirm against the 'Traced' implementation.
instance MonadDistribution m => MonadDistribution (SMC2 m) where
  random = lift random
-- | Scores are delegated to the wrapped
-- @Sequential (Traced (Population m))@ stack, so they are seen by the outer
-- sampler. Contrast this with 'random', which goes straight to @m@.
instance Monad m => MonadFactor (SMC2 m) where
  score = SMC2 . score
-- | 'SMC2' can both sample and score, so it is a full 'MonadMeasure'
-- whenever the base monad can sample.
instance (MonadDistribution m) => MonadMeasure (SMC2 m)
-- | Sequential Monte Carlo squared.
--
-- The outer sampler is 'rmsmc' run over the parameter model @param@.
-- For each outer particle, the model @m@ (applied to the sampled
-- parameter) is run as an inner particle filter via 'smcPush'. The inner
-- filter's final population is then flattened with 'population' into a
-- weighted list.
--
-- NOTE(review): presumably 'smcPush' pushes the inner evidence estimate into
-- the outer weights through the 'MonadFactor' instance of 'SMC2'. Confirm in
-- "Control.Monad.Bayes.Inference.SMC".
smc2 ::
  MonadDistribution m =>
  -- | @k@: number of SMC steps ('numSteps'), used for both the outer and
  -- the inner sampler
  Int ->
  -- | @n@: number of inner particles ('numParticles' of the inner 'smcPush')
  Int ->
  -- | @p@: number of outer particles ('numParticles' of the outer 'rmsmc')
  Int ->
  -- | @t@: number of MCMC steps ('numMCMCSteps') passed to 'rmsmc'
  Int ->
  -- | @param@: model over the parameters
  Sequential (Traced (Population m)) b ->
  -- | @m@: model given a parameter value
  (b -> Sequential (Population (SMC2 m)) a) ->
  -- | outer population; each outer particle carries the inner weighted samples
  Population m [(a, Log Double)]
smc2 k n p t param m =
  rmsmc outerMCMC outerSMC (param >>= setup . population . smcPush innerSMC . m)
  where
    -- Single-site MH rejuvenation for the outer (parameter) particles.
    outerMCMC = MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0}
    -- Outer particle system: p particles over k steps.
    outerSMC = SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial}
    -- Inner particle system (runs inside SMC2 m): n particles over k steps.
    innerSMC = SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}