{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}

-- |
-- Module      : Control.Monad.Bayes.Inference.MCMC
-- Description : Markov Chain Monte Carlo (MCMC)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : tweag.io
-- Stability   : experimental
-- Portability : GHC
module Control.Monad.Bayes.Inference.MCMC where

import Control.Monad.Bayes.Class (MonadDistribution)
import Control.Monad.Bayes.Traced.Basic qualified as Basic
import Control.Monad.Bayes.Traced.Common
  ( MHResult (MHResult, trace),
    Trace (probDensity),
    burnIn,
    mhTransWithBool,
  )
import Control.Monad.Bayes.Traced.Dynamic qualified as Dynamic
import Control.Monad.Bayes.Traced.Static qualified as Static
import Control.Monad.Bayes.Weighted (WeightedT, unweighted)
import Pipes ((>->))
import Pipes qualified as P
import Pipes.Prelude qualified as P

-- | Choice of proposal kernel stored in 'MCMCConfig'.
--
-- The only constructor is 'SingleSiteMH'.
data Proposal = SingleSiteMH

-- | Settings shared by the MCMC drivers in this module.
data MCMCConfig = MCMCConfig
  { -- | Proposal kernel. The drivers in this module do not currently
    -- inspect this field.
    proposal :: Proposal,
    -- | Number of steps handed to the @mh@ implementation by 'mcmc',
    -- 'mcmcBasic' and 'mcmcDynamic'. Not used by 'mcmcP', which produces an
    -- unbounded stream.
    numMCMCSteps :: Int,
    -- | Burn-in length: passed to 'burnIn' by the list-returning drivers and
    -- used as the number of leading results dropped by 'mcmcP'.
    numBurnIn :: Int
  }

-- | Default settings: the 'SingleSiteMH' proposal, a single MCMC step and no
-- burn-in.
defaultMCMCConfig :: MCMCConfig
defaultMCMCConfig =
  MCMCConfig
    { proposal = SingleSiteMH,
      numMCMCSteps = 1,
      numBurnIn = 0
    }

-- | Run Metropolis-Hastings on a program built with 'Static.TracedT' and
-- return the resulting chain as a list.
--
-- The program is run with @'Static.mh' ('numMCMCSteps' config)@. The weights
-- are then removed with 'unweighted', and the list is passed through
-- @'burnIn' ('numBurnIn' config)@. The 'proposal' field is not consulted.
mcmc :: (MonadDistribution m) => MCMCConfig -> Static.TracedT (WeightedT m) a -> m [a]
mcmc config model =
  burnIn (numBurnIn config) . unweighted $ Static.mh (numMCMCSteps config) model

-- | Run Metropolis-Hastings on a program built with 'Basic.TracedT' and
-- return the resulting chain as a list.
--
-- The program is run with @'Basic.mh' ('numMCMCSteps' config)@. The weights
-- are then removed with 'unweighted', and the list is passed through
-- @'burnIn' ('numBurnIn' config)@. The 'proposal' field is not consulted.
mcmcBasic :: (MonadDistribution m) => MCMCConfig -> Basic.TracedT (WeightedT m) a -> m [a]
mcmcBasic config model =
  burnIn (numBurnIn config) . unweighted $ Basic.mh (numMCMCSteps config) model

-- | Run Metropolis-Hastings on a program built with 'Dynamic.TracedT' and
-- return the resulting chain as a list.
--
-- The program is run with @'Dynamic.mh' ('numMCMCSteps' config)@. The
-- weights are then removed with 'unweighted', and the list is passed through
-- @'burnIn' ('numBurnIn' config)@. The 'proposal' field is not consulted.
mcmcDynamic :: (MonadDistribution m) => MCMCConfig -> Dynamic.TracedT (WeightedT m) a -> m [a]
mcmcDynamic config model =
  burnIn (numBurnIn config) . unweighted $ Dynamic.mh (numMCMCSteps config) model

-- | Draw independent traces until one has non-zero density.
--
-- The trace-generating action stored in the 'Static.TracedT' is run
-- repeatedly. Every trace whose 'probDensity' is zero is yielded downstream,
-- wrapped as @'MHResult' False@. The first trace with non-zero density is
-- not yielded; it becomes the producer's return value.
--
-- NOTE(review): if every drawn trace has zero density, this producer never
-- terminates.
independentSamples :: (Monad m) => Static.TracedT m a -> P.Producer (MHResult a) m (Trace a)
independentSamples (Static.TracedT _ draw) =
  P.repeatM draw
    -- takeWhile' returns the first element that fails the predicate.
    >-> P.takeWhile' ((== 0) . probDensity)
    >-> P.map (MHResult False)

-- | Convert a probabilistic program into a producer of MCMC samples.
--
-- The chain is seeded with the first non-zero-density trace returned by
-- 'independentSamples'. The zero-density draws that precede it are drained,
-- not yielded. From that seed, 'mhTransWithBool' is applied repeatedly. Each
-- result is yielded, and its 'trace' becomes the input to the next
-- transition. The seed trace itself is never yielded. The first
-- 'numBurnIn' results are dropped.
--
-- The stream is unbounded: 'numMCMCSteps' and 'proposal' are not consulted,
-- so the consumer decides how many samples to take.
mcmcP :: (MonadDistribution m) => MCMCConfig -> Static.TracedT m a -> P.Producer (MHResult a) m ()
mcmcP config model@(Static.TracedT w _) = do
  initialValue <- independentSamples model >-> P.drain
  P.unfoldr step initialValue >-> P.drop (numBurnIn config)
  where
    -- One MH transition: emit the result and continue from its trace.
    -- The producer never stops because 'Left' is never returned.
    step current = do
      result <- mhTransWithBool w current
      pure (Right (result, trace result))