{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Control.Monad.Bayes.Inference.MCMC where
import Control.Monad.Bayes.Class (MonadDistribution)
import Control.Monad.Bayes.Traced.Basic qualified as Basic
import Control.Monad.Bayes.Traced.Common
( MHResult (MHResult, trace),
Trace (probDensity),
burnIn,
mhTransWithBool,
)
import Control.Monad.Bayes.Traced.Dynamic qualified as Dynamic
import Control.Monad.Bayes.Traced.Static qualified as Static
import Control.Monad.Bayes.Weighted (Weighted, unweighted)
import Pipes ((>->))
import Pipes qualified as P
import Pipes.Prelude qualified as P
-- | The proposal kernel for the MCMC drivers in this module.
-- 'SingleSiteMH' is currently the only option.
data Proposal = SingleSiteMH
  deriving (Eq, Show)

-- | Settings shared by 'mcmc', 'mcmcBasic', 'mcmcDynamic' and 'mcmcP'.
--
-- Derives 'Eq' and 'Show' so that configurations can be logged, printed and
-- compared (e.g. in tests).
data MCMCConfig = MCMCConfig
  { proposal :: Proposal
  -- ^ Which proposal kernel to use. The drivers in this module do not
  -- inspect this field yet, since 'SingleSiteMH' is the only choice.
  , numMCMCSteps :: Int
  -- ^ Step count handed to the @mh@ function of the chosen tracing
  -- representation by 'mcmc', 'mcmcBasic' and 'mcmcDynamic'.
  -- 'mcmcP' ignores it because its chain is unbounded.
  , numBurnIn :: Int
  -- ^ Burn-in length. The batch drivers pass it to 'burnIn'; 'mcmcP' drops
  -- this many leading results from its stream.
  }
  deriving (Eq, Show)
-- | Baseline settings: the single-site MH proposal, a single MH step, and no
-- burn-in. Override individual fields with record-update syntax, e.g.
-- @defaultMCMCConfig {numMCMCSteps = 1000}@.
defaultMCMCConfig :: MCMCConfig
defaultMCMCConfig =
  MCMCConfig
    { numBurnIn = 0
    , numMCMCSteps = 1
    , proposal = SingleSiteMH
    }
-- | Batch Metropolis-Hastings for programs in the "Static" traced
-- representation.
--
-- Runs 'Static.mh' for 'numMCMCSteps' steps, removes the 'Weighted' layer
-- with 'unweighted', and then applies 'burnIn' with 'numBurnIn'.
-- The 'proposal' field is not consulted.
mcmc :: MonadDistribution m => MCMCConfig -> Static.Traced (Weighted m) a -> m [a]
mcmc MCMCConfig {numMCMCSteps = steps, numBurnIn = burn} program =
  burnIn burn (unweighted (Static.mh steps program))
-- | Batch Metropolis-Hastings for programs in the "Basic" traced
-- representation.
--
-- Identical in shape to 'mcmc', but drives 'Basic.mh': the chain runs for
-- 'numMCMCSteps' steps, its weights are dropped with 'unweighted', and
-- 'burnIn' is applied with 'numBurnIn'. The 'proposal' field is not consulted.
mcmcBasic :: MonadDistribution m => MCMCConfig -> Basic.Traced (Weighted m) a -> m [a]
mcmcBasic MCMCConfig {numMCMCSteps = steps, numBurnIn = burn} program =
  let chain = Basic.mh steps program
   in burnIn burn (unweighted chain)
-- | Batch Metropolis-Hastings for programs in the "Dynamic" traced
-- representation.
--
-- Identical in shape to 'mcmc', but drives 'Dynamic.mh': the chain runs for
-- 'numMCMCSteps' steps, its weights are dropped with 'unweighted', and
-- 'burnIn' is applied with 'numBurnIn'. The 'proposal' field is not consulted.
mcmcDynamic :: MonadDistribution m => MCMCConfig -> Dynamic.Traced (Weighted m) a -> m [a]
mcmcDynamic MCMCConfig {numMCMCSteps = n, numBurnIn = b} program =
  burnIn b samples
  where
    samples = unweighted (Dynamic.mh n program)
-- | Draw traces from the program's trace sampler, one after another, until
-- one has non-zero 'probDensity'.
--
-- Each rejected (zero-density) trace is yielded downstream, wrapped as
-- @'MHResult' False@. The first trace with non-zero density is not yielded;
-- it becomes the producer's return value instead. The weighted-model
-- component of the 'Static.Traced' value is unused here.
--
-- If the sampler only ever produces zero-density traces, this producer
-- never terminates.
independentSamples :: Monad m => Static.Traced m a -> P.Producer (MHResult a) m (Trace a)
independentSamples (Static.Traced _ sampler) =
  P.repeatM sampler >-> rejectZeroDensity >-> P.map (MHResult False)
  where
    -- takeWhile' forwards values while the predicate holds and returns the
    -- first value that fails it, which ends the whole pipeline.
    rejectZeroDensity = P.takeWhile' (\t -> probDensity t == 0)
-- | Streaming MCMC: turn a 'Static.Traced' program into a pipes 'P.Producer'
-- of Metropolis-Hastings results.
--
-- The producer works in three stages:
--
-- 1. It runs 'independentSamples' to find a starting trace with non-zero
--    density, discarding (via 'P.drain') everything that search yields.
-- 2. From that trace it repeatedly applies 'mhTransWithBool' to the program's
--    weighted model, yielding each 'MHResult' and continuing from its 'trace'.
-- 3. The first 'numBurnIn' results are dropped.
--
-- The chain is unbounded because 'numMCMCSteps' and 'proposal' are not
-- consulted. Consumers must limit the stream themselves, e.g. with 'P.take'.
mcmcP :: MonadDistribution m => MCMCConfig -> Static.Traced m a -> P.Producer (MHResult a) m ()
mcmcP MCMCConfig {numBurnIn = burn} program@(Static.Traced target _) = do
  start <- independentSamples program >-> P.drain
  P.unfoldr advance start >-> P.drop burn
  where
    -- One MH transition from the current trace. The step always returns
    -- 'Right', so the unfold never stops on its own.
    advance current = emit <$> mhTransWithBool target current
    -- Yield the result and continue the chain from its trace.
    emit result = Right (result, trace result)