{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
module Control.Monad.Bayes.Inference.Lazy.MH where
import Control.Monad.Bayes.Class (Log (ln))
import Control.Monad.Bayes.Sampler.Lazy
( Sampler (runSampler),
Tree (..),
Trees (..),
randomTree,
)
import Control.Monad.Bayes.Weighted (Weighted, weighted)
import Control.Monad.Extra (iterateM)
import Control.Monad.State.Lazy (MonadState (get, put), runState)
import System.Random (RandomGen (split), getStdGen, newStdGen)
import System.Random qualified as R
-- | Lazy Metropolis–Hastings over the tree of uniforms that drives a
-- lazy 'Sampler'.
--
-- @mh p m@ runs the model @m@ on an initial random 'Tree', then repeatedly:
--
--   * proposes a new tree by re-drawing each site of the current tree
--     independently with probability @p@ (see 'mutateTree'),
--   * re-runs @m@ on the proposed tree to get a new value and weight,
--   * accepts the proposal if a uniform draw falls below @min 1 (w' / w)@,
--     otherwise keeps the current state.
--
-- The result is the lazily produced, unbounded list of @(value, weight)@
-- pairs visited by the chain; consume only a finite prefix (e.g. 'take').
mh :: forall a. Double -> Weighted Sampler a -> IO [(a, Log Double)]
mh p m = do
  -- 'newStdGen' splits the global generator, stores one half back and
  -- returns the other. Use the returned half: the former
  -- @newStdGen >> getStdGen@ threw it away and used the half that stays in
  -- the global ref, so later 'newStdGen' / 'randomIO' calls (including a
  -- second call to 'mh') started from state this chain had already used.
  g <- newStdGen
  -- One half seeds the initial tree; the other drives the proposals and
  -- the accept/reject draws inside 'step'.
  let (gTree, gChain) = split g
      t0 = randomTree gTree
      (x0, w0) = runSampler (weighted m) t0
      -- Lazy State, so 'iterateM' can yield an unbounded chain on demand.
      (chain, _) = runState (iterateM step (t0, x0, w0)) gChain
  return [(x, w) | (_, x, w) <- chain]
  where
    -- One MH transition. The current (tree, value, weight) is threaded
    -- explicitly; the generator used for proposing and for the
    -- accept/reject draw is kept in the State monad.
    step (t, x, w) = do
      gen <- get
      let (gPropose, gAccept) = split gen
          t' = mutateTree p gPropose t
          (x', w') = runSampler (weighted m) t'
          -- No Hastings correction: the site-wise resampling in
          -- 'mutateTree' is symmetric.
          ratio = w' / w
          (u, gNext) = R.random gAccept
      put gNext
      return $
        if u < min 1 (exp (ln ratio))
          then (t', x', w')
          else (t, x, w)
-- | Perturb the root site of a 'Tree' and, lazily, all sites below it.
--
-- A uniform "coin" is drawn from @g@. If it is below @p@, the node's
-- 'currentUniform' is replaced by a second fresh uniform draw; otherwise
-- the old value is kept. The generator left after both draws is handed to
-- 'mutateTrees' to perturb the (lazy) children in the same way.
mutateTree :: forall g. RandomGen g => Double -> g -> Tree -> Tree
mutateTree p g (Tree a ts) =
  Tree {currentUniform = site, lazyUniforms = children}
  where
    (coin, gAfterCoin) = R.random g :: (Double, g)
    (fresh, gAfterFresh) = R.random gAfterCoin
    site
      | coin < p = fresh
      | otherwise = a
    children = mutateTrees p gAfterFresh ts
-- | Perturb a lazy stream of subtrees with 'mutateTree'.
--
-- The generator is split once per cell: one half mutates the head subtree,
-- the other is passed on to the rest of the stream. Nothing is forced
-- until a field of the result is demanded.
mutateTrees :: RandomGen g => Double -> g -> Trees -> Trees
mutateTrees p g (Trees t ts) =
  Trees
    { headTree = mutateTree p gHead t,
      tailTrees = mutateTrees p gTail ts
    }
  where
    (gHead, gTail) = split g