{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}

module Control.Monad.Bayes.Inference.Lazy.MH where

import Control.Monad.Bayes.Class (Log (ln))
import Control.Monad.Bayes.Sampler.Lazy
  ( SamplerT (runSamplerT),
    Tree (..),
    Trees (..),
    randomTree,
  )
import Control.Monad.Bayes.Weighted (WeightedT, runWeightedT)
import Control.Monad.Extra (iterateM)
import Control.Monad.State.Lazy (MonadState (get, put), runState)
import System.Random (RandomGen (split), getStdGen, newStdGen)
import System.Random qualified as R

-- | Lazy-tree Metropolis–Hastings over the lazy tree of uniforms that drives
-- a 'SamplerT' computation.
--
-- Returns a stream of @(result, weight)@ pairs for the model @m@. Each step
-- proposes a new tree by re-drawing every label independently with
-- probability @p@ (see 'mutateTree'), reruns the model on that tree to get a
-- new weight @w'@, and accepts the proposal with probability
-- @min 1 (w' / w)@; otherwise the previous state is repeated.
--
-- The stream is built by iterating @step@ in the lazy State monad, so it is
-- intended to be consumed lazily (e.g. with 'take').
-- NOTE(review): @p@ is presumably meant to lie in [0, 1] — not checked here.
mh :: forall a. Double -> WeightedT SamplerT a -> IO [(a, Log Double)]
mh p m = do
  -- Use the generator that 'newStdGen' returns, NOT 'getStdGen' afterwards:
  -- 'newStdGen' splits the global generator, keeps one half installed
  -- globally and returns the other. Reading the installed half back would
  -- make the next call to 'mh' derive its seed from this run's randomness.
  g <- newStdGen
  -- One half seeds the initial tree (the starting point in the state space);
  -- the other drives the proposal and accept/reject randomness of the chain.
  let (g1, g2) = split g
      t = randomTree g1
      (x, w) = runSamplerT (runWeightedT m) t
      -- Run 'step' over and over to get a stream of (tree, result, weight)s.
      (samples, _) = runState (iterateM step (t, x, w)) g2
  -- Drop the trees, keeping only the result/weight pairs.
  return [(x', w') | (_, x', w') <- samples]
  where
    {- NB There are three kinds of randomness in the step function.
    1. The current tree 't', which is the source of randomness for simulating
       the program m. This is sort-of the point in the "state space".
    2. The randomness needed to propose a new tree ('g1').
    3. The randomness needed to decide whether to accept or reject it ('g2').
    The tree is threaded through as argument and result, while the generator
    that supplies (2) and (3) lives in a state monad ('get'/'put'). -}
    step (t, x, w) = do
      g <- get
      let (g1, g2) = split g
          -- Proposal: randomly change some sites of the current tree.
          t' = mutateTree p g1 t
          -- Rerun the model on the proposed tree to get its weight w'.
          (x', w') = runSamplerT (runWeightedT m) t'
          -- MH acceptance ratio; the proposal is symmetric, so only the
          -- weights enter.
          ratio = w' / w
          (r, g2') = R.random g2
      put g2'
      return $
        if r < min 1 (exp (ln ratio))
          then (t', x', w')
          else (t, x, w)

-- | Perturb a lazy tree of uniforms, as used for an MH proposal.
--
-- The label at the root is replaced by a fresh uniform draw when a freshly
-- drawn coin value falls below @p@, and is kept otherwise; the subtrees are
-- perturbed recursively by 'mutateTrees' with the remaining generator.
mutateTree :: forall g. (RandomGen g) => Double -> g -> Tree -> Tree
mutateTree p gen (Tree old subtrees) =
  Tree
    { currentUniform = newLabel,
      lazyUniforms = mutateTrees p restGen subtrees
    }
  where
    -- First draw decides whether to resample; second draw is the candidate.
    (coin, gen') = R.random gen :: (Double, g)
    (candidate, restGen) = R.random gen'
    newLabel
      | coin < p = candidate
      | otherwise = old

-- | Perturb every tree in a lazy forest of uniforms.
--
-- The generator is split so that the head tree and the rest of the forest
-- are mutated with separate generators; each tree is handled by 'mutateTree'.
mutateTrees :: (RandomGen g) => Double -> g -> Trees -> Trees
mutateTrees p gen (Trees first rest) =
  Trees (mutateTree p headGen first) (mutateTrees p tailGen rest)
  where
    (headGen, tailGen) = split gen