{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Control.Monad.Bayes.Inference.MCMC where
import Control.Monad.Bayes.Class (MonadDistribution)
import Control.Monad.Bayes.Traced.Basic qualified as Basic
import Control.Monad.Bayes.Traced.Common
( MHResult (MHResult, trace),
Trace (probDensity),
burnIn,
mhTransWithBool,
)
import Control.Monad.Bayes.Traced.Dynamic qualified as Dynamic
import Control.Monad.Bayes.Traced.Static qualified as Static
import Control.Monad.Bayes.Weighted (Weighted, unweighted)
import Pipes ((>->))
import Pipes qualified as P
import Pipes.Prelude qualified as P
-- | Which proposal kernel an MCMC run should use.
--
-- 'SingleSiteMH' is the only option at present. NOTE(review): none of the
-- runners in this module ('mcmc', 'mcmcBasic', 'mcmcDynamic', 'mcmcP')
-- inspect the 'proposal' field, so this value is currently informational.
data Proposal
  = -- | Presumably single-site Metropolis–Hastings, per the name.
    SingleSiteMH
-- | Settings shared by the MCMC runners in this module.
data MCMCConfig = MCMCConfig
  { -- | Proposal strategy. NOTE(review): not read by any runner in this
    -- module.
    proposal :: Proposal,
    -- | Step count handed to the @mh@ function of the chosen tracing
    -- backend by 'mcmc', 'mcmcBasic' and 'mcmcDynamic'. Ignored by
    -- 'mcmcP', whose output stream is unbounded.
    numMCMCSteps :: Int,
    -- | Burn-in count. The list-returning runners pass it to 'burnIn';
    -- 'mcmcP' uses it to drop that many leading results from its stream.
    numBurnIn :: Int
  }
-- | A minimal configuration: the 'SingleSiteMH' proposal, a single MCMC
-- step and no burn-in.
defaultMCMCConfig :: MCMCConfig
defaultMCMCConfig =
  MCMCConfig
    { numBurnIn = 0,
      numMCMCSteps = 1,
      proposal = SingleSiteMH
    }
-- | Run Metropolis–Hastings on a program traced with
-- "Control.Monad.Bayes.Traced.Static".
--
-- The chain is built by 'Static.mh' with 'numMCMCSteps', and the 'Weighted'
-- layer is removed with 'unweighted'. The resulting list is passed through
-- 'burnIn' with 'numBurnIn' (see 'burnIn' for exactly which samples are
-- discarded). The 'proposal' field is not consulted.
mcmc :: MonadDistribution m => MCMCConfig -> Static.Traced (Weighted m) a -> m [a]
mcmc MCMCConfig {numMCMCSteps = steps, numBurnIn = burn} program =
  let weightedChain = Static.mh steps program
      chain = unweighted weightedChain
   in burnIn burn chain
-- | Run Metropolis–Hastings on a program traced with
-- "Control.Monad.Bayes.Traced.Basic".
--
-- The chain is built by 'Basic.mh' with 'numMCMCSteps', and the 'Weighted'
-- layer is removed with 'unweighted'. The resulting list is passed through
-- 'burnIn' with 'numBurnIn'. The 'proposal' field is not consulted.
mcmcBasic :: MonadDistribution m => MCMCConfig -> Basic.Traced (Weighted m) a -> m [a]
mcmcBasic MCMCConfig {numMCMCSteps = steps, numBurnIn = burn} program =
  let weightedChain = Basic.mh steps program
      chain = unweighted weightedChain
   in burnIn burn chain
-- | Run Metropolis–Hastings on a program traced with
-- "Control.Monad.Bayes.Traced.Dynamic".
--
-- The chain is built by 'Dynamic.mh' with 'numMCMCSteps', and the 'Weighted'
-- layer is removed with 'unweighted'. The resulting list is passed through
-- 'burnIn' with 'numBurnIn'. The 'proposal' field is not consulted.
mcmcDynamic :: MonadDistribution m => MCMCConfig -> Dynamic.Traced (Weighted m) a -> m [a]
mcmcDynamic MCMCConfig {numMCMCSteps = steps, numBurnIn = burn} program =
  let weightedChain = Dynamic.mh steps program
      chain = unweighted weightedChain
   in burnIn burn chain
-- | Draw independent traces until one has non-zero density.
--
-- The trace sampler (the second field of 'Static.Traced') is run
-- repeatedly. Every trace whose 'probDensity' is zero is yielded downstream
-- as a rejected result (@MHResult False@). The first trace with non-zero
-- density is not yielded; it is returned as the producer's result.
-- The producer never finishes if every draw has zero density.
independentSamples :: Monad m => Static.Traced m a -> P.Producer (MHResult a) m (Trace a)
independentSamples (Static.Traced _ sampleTrace) =
  P.repeatM sampleTrace >-> P.takeWhile' hasZeroDensity >-> P.map rejected
  where
    -- Traces that cannot serve as a chain's starting point.
    hasZeroDensity candidate = probDensity candidate == 0
    -- Tag each discarded draw as not accepted.
    rejected = MHResult False
-- | Stream Metropolis–Hastings results from a static traced program.
--
-- A starting trace is first obtained with 'independentSamples'. The
-- zero-density draws it emits are drained, not forwarded. From that trace,
-- 'mhTransWithBool' is applied over and over. Each 'MHResult' is yielded,
-- and the chain continues from that result's 'trace'. The first 'numBurnIn'
-- results are dropped.
--
-- The stream is unbounded: 'numMCMCSteps' and 'proposal' are not consulted,
-- so the consumer is responsible for stopping it (e.g. with @P.take@).
mcmcP :: MonadDistribution m => MCMCConfig -> Static.Traced m a -> P.Producer (MHResult a) m ()
mcmcP MCMCConfig {numBurnIn = burn} program@(Static.Traced kernelModel _) = do
  -- Consume (and discard) any zero-density draws; keep the first usable one.
  start <- independentSamples program >-> P.drain
  P.unfoldr step start >-> P.drop burn
  where
    -- One MH transition. It always returns 'Right', so the unfold never stops.
    step current = continueFrom <$> mhTransWithBool kernelModel current
    continueFrom result = Right (result, trace result)