{-# LANGUAGE GeneralizedNewtypeDeriving, ScopedTypeVariables #-}

{- | LazyPPL is a library for Bayesian probabilistic programming. It supports lazy use of probability, and we provide new Metropolis-Hastings simulation algorithms to allow this. Laziness appears to be a good paradigm for non-parametric statistics. 

Reference paper: [Affine Monads and Lazy Structures for Bayesian Programming](https://arxiv.org/abs/2212.07250). POPL 2023.

Illustrations: [https://lazyppl-team.github.io](https://lazyppl-team.github.io).


LazyPPL is inspired by recent ideas in synthetic probability theory and synthetic measure theory, such as [quasi-Borel spaces](https://ncatlab.org/nlab/show/quasi-Borel+space) and [Markov categories](https://ncatlab.org/nlab/show/Markov+category). It also draws on many other languages, including [Church](http://v1.probmods.org), [Anglican](https://probprog.github.io/anglican/), and [Monad-Bayes](https://hackage.haskell.org/package/monad-bayes). Monad-Bayes now includes a LazyPPL-inspired simulation algorithm.

This module defines

    1. Two monads: `Prob` (for probability measures) and `Meas` (for unnormalized measures), with interface `uniform`, `sample`, `score`. 

    2. Monte Carlo inference methods produce samples from an unnormalized measure. We provide three inference methods: 

        a. 'mh' (Metropolis-Hastings algorithm based on lazily mutating parts of the tree at random).

        b. 'mhirreducible', which randomly restarts for a properly irreducible Metropolis-Hastings kernel.

        c. 'wis' (simple reference weighted importance sampling)

        See also the SingleSite module for a separate single-site Metropolis-Hastings algorithm via GHC.Exts.Heap and System.IO.Unsafe.

    3. Various useful helper functions.

    A typical usage would be

@   
    import LazyPPL (Prob, Meas, uniform, sample, score, mh, every)
@

    Most of the structure here will not be needed in typical models. We expose more of the structure for more experimental uses. 

The `Distributions` module provides many useful distributions, and further non-parametric distributions are in `Distributions.DirichletP`, `Distributions.GP`, `Distributions.IBP`, and `Distributions.Memoization`. 


-}

module LazyPPL
    ( -- * Rose tree type
      --
      -- | Our source of randomness will be an infinitely wide and deep lazy [rose tree](https://en.wikipedia.org/wiki/Rose_tree), regarded as initialized with uniform [0,1] choices for each label.
      Tree(Tree),
      -- * Monads
      Prob(Prob), Meas(Meas),
      -- * Basic interface
      --
      -- | There are three building blocks for measures: `uniform` for probability measures; `sample` and `score` for unnormalized measures. Combined with the monad structure, these give all s-finite measures.
      uniform, sample, score,
      -- * Monte Carlo simulation
      --
      -- | The `Meas` type describes unnormalized measures. Monte Carlo simulation allows us to sample from an unnormalized measure. Our main Monte Carlo simulator is `mh`. 
      mh, mhirreducible, weightedsamples, wis, 
      -- * Useful functions
      every, randomTree, runProb) where

import Control.Monad.Trans.Writer
import Control.Monad.Trans.Class
import Data.Monoid
import System.Random hiding (uniform)
import Control.Monad
import Control.Monad.Extra
import Control.Monad.State.Lazy (State, state , put, get, runState)
import Numeric.Log


{- | A `Tree` here is a lazy, infinitely wide and infinitely deep rose tree,
labelled by Doubles. It is the source of randomness for 'Prob': each label is
regarded as a uniform [0,1] choice.

A list or stream would be the more usual choice, but a tree lets a program be
lazy about how much randomness it uses in every direction at once. -}
data Tree
  = Tree
      Double  -- ^ the label at this node
      [Tree]  -- ^ the subtrees; code such as 'splitTree' expects this list to be non-empty (infinite in practice)

{- | A probability distribution over @a@ is a function @Tree -> a@.

We can think of this as the law of a random variable, indexed by the source of
randomness, which is `Tree`.

Under the monad implementation, a program uses up bits of the tree as it runs:
sequencing hands different subtrees to different computations. The tree being
infinitely wide and deep allows for lazy computation. -}
newtype Prob a
  = Prob (Tree -> a)

-- | Split a tree in two (bijectively): the first subtree becomes the left
-- part, and the root label together with the remaining subtrees becomes the
-- right part. Rebuilding @Tree r (t : ts)@ from @(t, Tree r ts)@ inverts it.
--
-- Trees built by 'randomTree' have an infinite list of subtrees, so this never
-- runs out. A finite, childless tree cannot be split; that case now fails with
-- an explicit message instead of an anonymous pattern-match failure.
splitTree :: Tree -> (Tree, Tree)
splitTree (Tree r (t : ts)) = (t, Tree r ts)
splitTree (Tree _ []) =
  error "LazyPPL.splitTree: cannot split a tree that has no subtrees"


-- | 'fmap' is defined via the monad structure ('liftM'), so @fmap f m@
-- consumes the tree exactly as @m >>= return . f@ does.
instance Functor Prob where
  fmap = liftM

-- | @pure a@ ignores its tree entirely. '<*>' is 'ap', so it sequences via
-- '>>=' and therefore splits the tree like any other bind.
instance Applicative Prob where
  pure a = Prob (const a)
  (<*>) = ap

-- | Sequencing is done by splitting the tree
-- and using different bits for different computations.
--
-- This monad structure is strongly inspired by the probability monad of [quasi-Borel space](https://ncatlab.org/nlab/show/quasi-Borel+space#probability_distributions).
--
-- 'return' is deliberately not defined here: it defaults to 'pure', which is
-- the canonical form GHC expects (see @-Wnoncanonical-monad-instances@).
instance Monad Prob where
  Prob m >>= f = Prob $ \g ->
    -- The first subtree drives @m@; the rest of the tree drives the
    -- continuation, so the two computations use disjoint randomness.
    let (g1, g2) = splitTree g
        Prob m' = f (m g1)
    in m' g2

{- | An unnormalized measure, represented by a probability distribution ('Prob')
over pairs of a result and a weight.

The weight is a 'Log' 'Double' combined multiplicatively through the 'Product'
monoid of a writer transformer. The 'Functor', 'Applicative' and 'Monad'
instances are inherited from that 'WriterT' via @GeneralizedNewtypeDeriving@. -}
newtype Meas a = Meas (WriterT (Product (Log Double)) Prob a)
  deriving
    ( Functor
    , Applicative
    , Monad
    )

{- | A uniform sample is a building block for probability distributions.

It reads the label at the root of the tree and ignores all of the subtrees. -}
uniform :: Prob Double
uniform = Prob rootLabel
  where
    rootLabel (Tree r _) = r

-- | Regard a probability measure as an unnormalized measure: the
-- distribution is lifted into the writer layer without contributing a weight.
sample :: Prob a -> Meas a
sample = Meas . lift

{- | A one point measure with a given score (or weight, or mass, or likelihood),
which should be a positive real number.

A score of 0 describes impossibility. To avoid numeric issues, we encode it as
@ exp(-300) @ instead.-}
score :: Double -> Meas ()
score r = Meas (tell (Product (Exp (log clamped))))
  where
    -- Replace an exact zero by a tiny positive weight so that its log is finite.
    clamped
      | r == 0 = exp (-300)
      | otherwise = r

-- | Like 'score', but the weight is supplied directly as a 'Log' 'Double'.
-- Unlike 'score', no special treatment of a zero weight is applied.
scoreLog :: Log Double -> Meas ()
scoreLog = Meas . tell . Product

-- | Multiply the accumulated weight by a weight that is already wrapped in the
-- 'Product' monoid.
scoreProductLog :: Product (Log Double) -> Meas ()
scoreProductLog = Meas . tell


{- | Generate a tree with uniform random labels.

    The root label is drawn from the generator with 'random'. Every subtree
    gets its own generator, obtained by repeatedly 'split'ting the remainder,
    so the whole tree is built lazily from one seed. -}
randomTree :: RandomGen g => g -> Tree
randomTree g = Tree label (randomTrees g')
  where
    (label, g') = random g

-- | An infinite list of independent random trees. Each generator in the
-- stream is split: its first half builds a tree, its second half is split
-- again for the rest of the list.
randomTrees :: RandomGen g => g -> [Tree]
randomTrees = map (randomTree . fst) . iterate (split . snd) . split

{- | 'runProb' runs a probability deterministically, given a source of
randomness: it simply unwraps the underlying @Tree -> a@ function. -}
runProb :: Prob a -> Tree -> a
runProb p = case p of
  Prob sampler -> sampler

{- | Runs an unnormalized measure over and over, producing an infinite stream of
(result,weight) pairs. Each run reads a different part of a single random tree.

These are not samples from the renormalized distribution, just plain
(result,weight) pairs. This is useful when the distribution is known to be
normalized already. -}
weightedsamples :: forall a. Meas a -> IO [(a,Log Double)]
weightedsamples (Meas m) = do
    -- Advance the global generator, then use the new global state as the seed.
    _ <- newStdGen
    seed <- getStdGen
    let draws = runProb stream (randomTree seed)
    return (map (fmap getProduct) draws)
  where
    -- Infinitely many runs of the model. Each bind splits the tree, so the
    -- head run and the recursive tail read disjoint subtrees.
    stream :: Prob [(a, Product (Log Double))]
    stream = do
      (x, w) <- runWriterT m
      rest <- stream
      return ((x, w) : rest)

{- | Weighted importance sampling first draws @n@ weighted samples,
    and then samples a stream of results from them, regarded as an empirical
    distribution. Sometimes called "likelihood weighted importance sampling".

This is a reference implementation. It will not usually be very efficient at
all, but may be useful for debugging.

Throws an 'IOError' if @n@ is not positive, since there would be nothing to
resample from.
 -}
wis :: Int -- ^ @n@, the number of samples to base on (must be positive)
    -> Meas a -- ^ @m@, the measure to normalize
    -> IO [a] -- ^ Returns a stream of samples
wis n m = do
  when (n <= 0) $
    ioError (userError "wis: the number of samples n must be positive")
  xws <- weightedsamples m
  -- Pair each of the first n results with the running total of the weights.
  -- Each sample appears exactly once, so all n samples are used.
  let cumulative = take n (accumulate xws 0)
      -- weightedsamples yields an infinite stream and n > 0, so this is non-empty.
      total = snd (last cumulative)
  _ <- newStdGen
  g <- getStdGen
  let rs = randoms g :: [Double]
      -- Inverse-CDF lookup: the first result whose cumulative weight reaches
      -- r * total. Fall back to the last result in case rounding leaves no match.
      pick r = case filter (\(_, w) -> w >= Exp (log r) * total) cumulative of
        (x, _) : _ -> x
        [] -> fst (last cumulative)
  return (map pick rs)
  where
    -- Running (prefix) sums of the weights, one output pair per input pair.
    accumulate :: Num t => [(b, t)] -> t -> [(b, t)]
    accumulate ((x, w) : rest) acc =
      let acc' = w + acc
      in (x, acc') : accumulate rest acc'
    accumulate [] _ = []

{- | Produce a stream of samples, using Metropolis Hastings simulation.

   The weights are also returned. Often the weights can be discarded, but
   sometimes we may search for a sample of maximum score.

   The algorithm works as follows.

   At each step, we randomly change some sites (nodes in the tree).
   We then accept or reject these proposed changes, using a probability that is
   determined by the weight of the measure at the new tree.
   If rejected, we repeat the previous sample.

    This kernel is related to the one introduced by [Wingate, Stuhlmuller, Goodman, AISTATS 2011](http://proceedings.mlr.press/v15/wingate11a/wingate11a.pdf), but it is different in that it works when the number of sites is unknown. Moreover, since a site is a path through the tree, the address is more informative than a number, which avoids some addressing issues.

    When 1/@p@ is roughly the number of used sites, then this will be a bit like "single-site lightweight" MH.
    If @p@ = 1 then this is "multi-site lightweight" MH.

    __Tip:__ if @m :: Prob a@ then use @map fst <$> (mh 1 $ sample m)@ to get a stream of samples from a probability distribution without conditioning. 
-}
mh :: forall a.
   Double -- ^ The chance @p@ of changing any site
   -> Meas a -- ^ The unnormalized measure to sample from
   -> IO [(a,Product (Log Double))] -- ^ Returns a stream of (result,weight) pairs
mh p (Meas m) = do
    -- Refresh the global generator and fork it: one half seeds the starting
    -- tree (the initial point in the state space), the other half drives the
    -- proposals and the accept/reject decisions of the chain.
    _ <- newStdGen
    seed <- getStdGen
    let (treeGen, chainGen) = split seed
        start = randomTree treeGen
        (x0, w0) = weigh start
        -- Iterate 'step' in the lazy State monad to get an infinite chain of
        -- (tree, result, weight) states.
        chain = fst (runState (iterateM step (start, x0, w0)) chainGen)
    return (map forget chain)
  where
    -- Run the model on a given tree, giving its result and weight.
    weigh :: Tree -> (a, Product (Log Double))
    weigh = runProb (runWriterT m)

    -- Drop the tree from a chain state; callers only see (result, weight).
    forget :: (Tree, a, Product (Log Double)) -> (a, Product (Log Double))
    forget (_, x, w) = (x, w)

    -- One MH transition. The current tree is passed explicitly, while the
    -- generator lives in the State monad. The generator is split into one
    -- part for proposing a new tree and one part for the accept/reject draw.
    step :: RandomGen g => (Tree, a, Product (Log Double)) -> State g (Tree, a, Product (Log Double))
    step current@(tree, _, weight) = do
      gen <- get
      let (proposalGen, acceptGen) = split gen
          -- Randomly change some sites, then rerun the model on the new tree.
          tree' = mutateTree p proposalGen tree
          (x', weight') = weigh tree'
          -- MH acceptance ratio: the new weight relative to the old one.
          ratio = getProduct weight' / getProduct weight
          (u, acceptGen') = random acceptGen
      put acceptGen'
      return $
        if u < min 1 (exp (ln ratio))
          then (tree', x', weight')
          else current


-- | Replace the labels of a tree randomly, with probability @p@.
--
-- At every node two values are drawn with 'random': a coin, which is compared
-- against @p@, and a candidate label. If the coin is below @p@ the node takes
-- the candidate label; otherwise it keeps its old one. Both draws happen even
-- when the label is kept. The subtrees are mutated lazily, each with its own
-- generator obtained by 'split'.
mutateTree :: forall g. RandomGen g => Double -> g -> Tree -> Tree
mutateTree p g (Tree a ts) =
  let (a', g') = (random g :: (Double, g)) in
  let (a'', g'') = random g' in
  Tree (if a' < p then a'' else a) (mutateTrees p g'' ts)

-- | Apply 'mutateTree' to every tree in a list, giving each its own generator
-- split off from @g@. The list is traversed lazily, so infinite child lists
-- work. A finite list, including the empty one, is handled as well, instead
-- of failing with a pattern-match error.
mutateTrees :: RandomGen g => Double -> g -> [Tree] -> [Tree]
mutateTrees _ _ [] = []
mutateTrees p g (t : ts) =
  let (g1, g2) = split g
  in mutateTree p g1 t : mutateTrees p g2 ts



{- | Irreducible form of 'mh'. Takes @p@ like 'mh', but also @q@, which is the chance of proposing an all-sites change. Irreducibility means that, asymptotically, the sequence of samples will converge in distribution to the renormalized version of @m@. 

The kernel in `mh` is not formally irreducible in the usual sense, although it is an open question whether this is a problem for asymptotic convergence in any definable model. In any case, convergence is only asymptotic, and so it can be helpful to use `mhirreducible` in some situations.

Roughly this avoids `mh` getting stuck in one particular mode, although it is a rather brutal method.

 -}
mhirreducible :: forall a.
              Double -- ^ The chance @p@ of changing any given site
              -> Double -- ^ The chance @q@ of doing an all-sites change
              -> Meas a -- ^ The unnormalized measure @m@ to sample from
              -> IO [(a,Product (Log Double))] -- ^ Returns a stream of (result,weight) pairs
mhirreducible p q (Meas m) = do
    -- Refresh the global generator and fork it: one half seeds the starting
    -- tree, the other half drives the proposals and accept/reject decisions.
    _ <- newStdGen
    seed <- getStdGen
    let (treeGen, chainGen) = split seed
        start = randomTree treeGen
        (x0, w0) = weigh start
        -- Iterate 'step' in the lazy State monad to get an infinite chain of
        -- (tree, result, weight) states.
        chain = fst (runState (iterateM step (start, x0, w0)) chainGen)
    return (map forget chain)
  where
    -- Run the model on a given tree, giving its result and weight.
    weigh :: Tree -> (a, Product (Log Double))
    weigh = runProb (runWriterT m)

    -- Drop the tree from a chain state; callers only see (result, weight).
    forget :: (Tree, a, Product (Log Double)) -> (a, Product (Log Double))
    forget (_, x, w) = (x, w)

    -- One MH transition. Like 'mh', except that with chance q the proposal is a
    -- completely fresh tree rather than a partial mutation of the current one.
    step :: RandomGen g => (Tree, a, Product (Log Double)) -> State g (Tree, a, Product (Log Double))
    step current@(tree, _, weight) = do
      gen <- get
      let (proposalGen, acceptGen) = split gen
          -- Decide whether to resample all sites (coin < q) or just some of them.
          (coin, proposalGen') = random proposalGen
          tree'
            | coin < q = randomTree proposalGen'
            | otherwise = mutateTree p proposalGen' tree
          -- Rerun the model with the proposed tree.
          (x', weight') = weigh tree'
          -- MH acceptance ratio: the new weight relative to the old one.
          ratio = getProduct weight' / getProduct weight
          (u, acceptGen') = random acceptGen
      put acceptGen'
      return $
        if u < min 1 (exp (ln ratio))
          then (tree', x', weight')
          else current



{- | Useful function which thins out a stream of results, as is common in Markov
Chain Monte Carlo simulation.

@every n xs@ keeps every @n@-th element of @xs@, starting with the @n@-th one.
These are the elements at 1-based positions @n@, @2n@, @3n@, and so on. Trailing
elements that do not complete a group of @n@ are dropped. For @n <= 1@ the list
is returned unchanged. The input may be infinite. -}
every :: Int -> [a] -> [a]
every n = go
  where
    go zs = case drop (n - 1) zs of
      [] -> []
      y : ys -> y : go ys