{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Monad.Bayes.Sampler.Lazy where
import Control.Monad (ap)
import Control.Monad.Bayes.Class (MonadDistribution (random))
import Control.Monad.Bayes.Weighted (WeightedT, runWeightedT)
import Numeric.Log (Log (..))
import System.Random
( RandomGen (split),
getStdGen,
newStdGen,
)
import System.Random qualified as R
-- | A lazily generated, infinitely deep and infinitely wide tree of random
-- labels. This is the source of randomness that 'SamplerT' computations read.
--
-- Both fields are intentionally lazy. 'randomTree' builds an infinite
-- structure, and only the parts a computation actually inspects are ever
-- generated. Strict fields would make construction diverge.
data Tree = Tree
  { -- | The label stored at this node.
    currentUniform :: Double,
    -- | The stream of child trees hanging below this node.
    lazyUniforms :: Trees
  }
-- | A stream of 'Tree's with no empty case: every value has a head tree and
-- a tail stream.
--
-- The fields must stay lazy because 'randomTrees' builds the stream without
-- end.
data Trees = Trees
  { -- | The first tree in the stream.
    headTree :: Tree,
    -- | The remaining trees.
    tailTrees :: Trees
  }
-- | A sampler for values of type @a@: a function that reads whatever random
-- labels it needs from a 'Tree'.
--
-- Sequencing (see the 'Monad' instance) hands different parts of a
-- computation disjoint sub-trees, so they never read the same label.
newtype SamplerT a = SamplerT {runSamplerT :: Tree -> a}
  deriving (Functor)
-- | Split a tree into two trees that occupy disjoint positions of the
-- original:
--
-- * the first child tree;
-- * a tree that keeps the original root label over the remaining children.
--
-- The split is invertible (the input can be rebuilt from the pair), so no
-- label is lost or duplicated.
splitTree :: Tree -> (Tree, Tree)
splitTree (Tree rootLabel (Trees firstChild otherChildren)) =
  (firstChild, Tree rootLabel otherChildren)
-- | Build an infinite 'Tree' from a generator.
--
-- The root label is a single 'R.random' draw. The children come from the
-- generator that this draw leaves behind (via 'randomTrees'). Construction
-- is lazy, so only the inspected parts of the tree are ever produced.
randomTree :: (RandomGen g) => g -> Tree
randomTree g =
  let (label, g') = R.random g
   in Tree label (randomTrees g')
-- | Build an infinite stream of trees by repeatedly 'split'ting the
-- generator.
--
-- One half seeds the head tree (through 'randomTree'); the other half seeds
-- the rest of the stream. Nothing is generated until it is inspected.
randomTrees :: (RandomGen g) => g -> Trees
randomTrees g =
  let (g1, g2) = split g
   in Trees (randomTree g1) (randomTrees g2)
-- | 'pure' ignores the tree entirely.
--
-- '<*>' is defined through the 'Monad' instance ('ap'), so applicative
-- combinations split the tree exactly like monadic sequencing does.
instance Applicative SamplerT where
  pure x = SamplerT (const x)
  (<*>) = ap
-- | Sequencing splits the incoming tree with 'splitTree'. The first
-- computation runs on one half and the continuation runs on the other, so
-- the two never read the same labels.
--
-- 'return' is deliberately left to its class default, which is 'pure'.
instance Monad SamplerT where
  (SamplerT m) >>= f = SamplerT $ \g ->
    let (g1, g2) = splitTree g
     in runSamplerT (f (m g1)) g2
-- | 'random' reads the label stored at the root of the tree it is given.
instance MonadDistribution SamplerT where
  random = SamplerT currentUniform
-- | Run a 'SamplerT' computation once, in 'IO'.
--
-- The computation reads a fresh random 'Tree'. The tree is built from the
-- generator that 'newStdGen' splits off and /returns/, not from the one it
-- leaves in the global slot.
--
-- Why this matters: the generator left in the global slot is what later
-- 'newStdGen' \/ 'R.randomIO' calls consume. Building the tree from that one
-- (e.g. by reading it back with 'getStdGen') would make later global draws
-- repeat this tree's labels.
--
-- Doing it in one step also makes the split-and-take atomic.
sampler :: SamplerT a -> IO a
sampler m = runSamplerT m . randomTree <$> newStdGen
-- | Run the same computation over and over, collecting every result in an
-- infinite list.
--
-- The list is built with 'sequence' over an infinite list. It is therefore
-- only productive in monads whose sequencing is lazy, such as 'SamplerT'
-- here, where each step gets its own lazily split sub-tree, or the reader
-- monad. In strict monads such as 'IO' it diverges.
independent :: (Monad m) => m a -> m [a]
independent = sequence . repeat
-- | Draw an infinite, lazily generated list of weighted samples from a
-- 'WeightedT' 'SamplerT' model.
--
-- Each element is a @(result, weight)@ pair as produced by 'runWeightedT'.
-- Repeated runs are obtained with 'independent', and the whole list comes
-- from a single call to 'sampler'.
--
-- The list never ends, so consume it lazily (e.g. with 'take').
weightedsamples :: WeightedT SamplerT a -> IO [(a, Log Double)]
weightedsamples = sampler . independent . runWeightedT