{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Monad.Bayes.Sampler.Lazy where
import Control.Monad (ap)
import Control.Monad.Bayes.Class (MonadDistribution (random))
import Control.Monad.Bayes.Weighted (Weighted, weighted)
import Numeric.Log (Log (..))
import System.Random
( RandomGen (split),
getStdGen,
newStdGen,
)
import System.Random qualified as R
-- | One node of an infinite, lazily built tree of random draws.
--
-- Both fields are deliberately lazy: 'randomTree' and 'randomTrees' build
-- this structure with no base case, so forcing the fields eagerly would never
-- terminate.
data Tree = Tree
  { currentUniform :: Double
    -- ^ The random 'Double' stored at this node (filled in by 'randomTree').
  , lazyUniforms :: Trees
    -- ^ The infinite sequence of child trees hanging off this node.
  }
-- | An infinite stream of 'Tree's: a first tree followed by the rest.
--
-- There is no "empty" constructor; the stream never ends. The fields must
-- stay lazy so the stream is only generated as far as it is inspected.
data Trees = Trees
  { headTree :: Tree
    -- ^ The first tree of the stream.
  , tailTrees :: Trees
    -- ^ Everything after the first tree.
  }
-- | A sampler is a function that reads whatever random draws it needs out of
-- an infinite 'Tree' of uniforms and produces a result.
newtype Sampler a = Sampler
  { runSampler :: Tree -> a
    -- ^ Run the sampler against a particular tree of random draws.
  }
  deriving (Functor)
-- | Split a node into two trees.
--
-- The first component is the node's first child. The second is the node
-- itself with that first child removed from its list of children; it keeps
-- the node's own 'currentUniform'.
splitTree :: Tree -> (Tree, Tree)
splitTree node = case node of
  Tree u (Trees firstChild otherChildren) ->
    (firstChild, Tree u otherChildren)
-- | Build an infinite, lazily evaluated 'Tree' from a generator.
--
-- The root's uniform comes from 'R.random'. The generator left over after
-- that draw seeds the root's children via 'randomTrees'.
randomTree :: RandomGen g => g -> Tree
randomTree gen = Tree u (randomTrees gen')
  where
    -- Lazy pattern binding: nothing is drawn until a field is demanded.
    (u, gen') = R.random gen
-- | Build an infinite, lazily evaluated stream of 'Tree's from a generator.
--
-- The generator is 'split' in two. One half seeds the first tree and the
-- other half seeds the rest of the stream, recursively and without end.
randomTrees :: RandomGen g => g -> Trees
randomTrees gen = Trees (randomTree left) (randomTrees right)
  where
    -- Lazy pattern binding, so only the demanded part of the stream is built.
    (left, right) = split gen
-- | 'pure' ignores the tree entirely.
--
-- '<*>' is defined via the 'Monad' instance: the function-producing sampler
-- is bound first and the argument sampler second, so the two are run on
-- different parts of the tree, as '>>=' arranges.
instance Applicative Sampler where
  pure x = Sampler (\_ -> x)
  mf <*> mx = mf >>= \f -> mx >>= \x -> pure (f x)
-- | Sequencing splits the tree with 'splitTree'.
--
-- The first computation runs on the node's first child. The continuation
-- runs on the remainder (the same node minus that child). 'return' is left
-- to its class default, which is 'pure'.
instance Monad Sampler where
  sa >>= k = Sampler $ \tree ->
    let (here, rest) = splitTree tree
        a = runSampler sa here
     in runSampler (k a) rest
-- | A single random draw is simply the uniform stored at the root of the
-- tree the sampler is run on. The remaining class methods use their defaults.
instance MonadDistribution Sampler where
  random = Sampler currentUniform
-- | Run a 'Sampler' in 'IO' against a fresh lazy tree of random draws.
--
-- 'newStdGen' splits the global generator, stores one half back as the new
-- global state and returns the other half. The tree is seeded from that
-- returned half. Seeding from the global state instead (e.g. via
-- 'getStdGen') would reuse the very generator that the next call splits,
-- tying the randomness of consecutive runs together.
sampler :: Sampler a -> IO a
sampler m = runSampler m . randomTree <$> newStdGen
-- | Repeat a computation forever, collecting every result into one infinite
-- list.
--
-- This only produces a usable (lazily consumable) list in monads whose
-- sequencing is lazy, such as 'Sampler' here or the function monad. In a
-- strict monad such as 'IO' it never returns.
independent :: Monad m => m a -> m [a]
independent m = go
  where
    -- Same unfolding as sequencing @repeat m@: each step conses one more
    -- result onto the (lazily built) rest of the stream.
    go = (:) <$> m <*> go
-- | Draw an infinite, lazily produced list of weighted samples from a model.
--
-- Each element is one run of the model paired with the weight reported by
-- 'weighted'. The list never ends (it comes from 'independent'), so callers
-- should consume it lazily, e.g. with 'take'.
weightedsamples :: Weighted Sampler a -> IO [(a, Log Double)]
weightedsamples model = sampler (independent (weighted model))