{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | This is a port of the implementation of LazyPPL: https://lazyppl.bitbucket.io/
module Control.Monad.Bayes.Sampler.Lazy where

import Control.Monad (ap)
import Control.Monad.Bayes.Class (MonadDistribution (random))
import Control.Monad.Bayes.Weighted (Weighted, weighted)
import Numeric.Log (Log (..))
import System.Random
  ( RandomGen (split),
    getStdGen,
    newStdGen,
  )
import System.Random qualified as R

-- | A lazy, infinitely wide and infinitely deep tree labelled by 'Double's.
--
-- Our source of randomness is a 'Tree' populated with uniform [0,1]
-- choices at every label. A list or stream is the more common choice,
-- but a tree lets each sub-computation take its own branch without
-- caring how much randomness the others consume.
--
-- The fields carry no strictness annotations on purpose: that laziness is
-- what allows 'randomTree' to describe an infinite tree that is only
-- materialised where it is inspected.
data Tree = Tree
  { currentUniform :: Double
  , lazyUniforms :: Trees
  }

-- | An infinite stream of 'Tree's: a first tree followed by the rest of
-- the stream. Like 'Tree', its fields are lazy so the stream is only
-- built as far as it is demanded.
data Trees = Trees
  { headTree :: Tree
  , tailTrees :: Trees
  }

-- | A probability distribution over @a@, represented as a function
-- @'Tree' -> a@.
--
-- The idea is that a sampler uses up parts of the randomness tree as it
-- runs; sequencing (see the 'Monad' instance) hands different subtrees to
-- different steps.
newtype Sampler a = Sampler {runSampler :: Tree -> a}
  deriving (Functor)

-- | Split a tree into two trees (bijectively).
--
-- The first result is the head of the tree's child stream. The second
-- keeps the original root label together with the remaining children.
-- The inverse is @\\(t, Tree r ts) -> Tree r (Trees t ts)@. The other key
-- operation on trees is reading the root label and discarding the rest;
-- that is what the 'MonadDistribution' instance's 'random' does.
splitTree :: Tree -> (Tree, Tree)
splitTree tree = case tree of
  Tree label (Trees firstChild otherChildren) ->
    (firstChild, Tree label otherChildren)

-- | Preliminaries for the simulation methods: generate a tree with uniform
-- random labels from a seed.
--
-- The root label is drawn with 'R.random'. The generator returned
-- alongside it seeds the child stream, which 'randomTrees' builds by
-- repeatedly 'split'ting the seed. The pattern binding below is lazy, so
-- nothing is drawn until the corresponding part of the tree is inspected.
randomTree :: RandomGen g => g -> Tree
randomTree gen = Tree label (randomTrees gen')
  where
    (label, gen') = R.random gen

-- | Generate an infinite stream of random 'Tree's from a seed.
--
-- Each step 'split's the generator: one half seeds the next tree (via
-- 'randomTree') and the other half seeds the rest of the stream.
randomTrees :: RandomGen g => g -> Trees
randomTrees gen = Trees (randomTree left) (randomTrees right)
  where
    (left, right) = split gen

-- | 'pure' ignores the randomness tree entirely. '<*>' is 'ap', so it goes
-- through '>>=' and splits the tree the same way monadic sequencing does.
instance Applicative Sampler where
  pure x = Sampler (\_ -> x)
  (<*>) = ap

-- | 'Sampler' is a monad for probabilities.
--
-- Sequencing is done by splitting the tree: the first computation reads the
-- first half produced by 'splitTree' and the continuation reads the second
-- half, so the two steps use different bits of randomness. 'return' is left
-- at its default, 'pure'.
instance Monad Sampler where
  m >>= k = Sampler $ \tree ->
    let (left, right) = splitTree tree
        x = runSampler m left
     in runSampler (k x) right

-- | A uniform draw is the label at the root of the current tree; the rest
-- of the tree is discarded.
instance MonadDistribution Sampler where
  random = Sampler currentUniform

-- | Run a 'Sampler' in 'IO', seeding its randomness tree from the global
-- 'StdGen'.
--
-- 'newStdGen' splits the global generator: one half is stored back as the
-- new global generator and the other half is returned. The tree is built
-- from the /returned/ half, so it does not share state with whatever
-- later draws from the global generator.
--
-- (The previous version discarded that half and rebuilt the tree from
-- 'getStdGen'. That is exactly the generator left in global state, so a
-- subsequent global draw such as 'System.Random.randomIO' reproduced the
-- tree's root label.)
sampler :: Sampler a -> IO a
sampler m = runSampler m . randomTree <$> newStdGen

-- | Repeat a computation forever and collect every result.
--
-- The resulting list is infinite. This is only useful in a monad whose
-- sequencing is lazy enough to yield a prefix of it, such as 'Sampler',
-- where each repetition reads its own part of the randomness tree.
independent :: Monad m => m a -> m [a]
independent action = sequenceA (repeat action)

-- | Run a weighted probability measure repeatedly and return the infinite
-- stream of @(result, weight)@ pairs, one per independent run.
--
-- The list is infinite, so consumers should only take a finite prefix of
-- it.
weightedsamples :: Weighted Sampler a -> IO [(a, Log Double)]
weightedsamples model = sampler (independent (weighted model))