module Control.Monad.Bayes.Inference.Lazy.WIS where
import Control.Monad.Bayes.Sampler.Lazy (Sampler, weightedsamples)
import Control.Monad.Bayes.Weighted (Weighted)
import Numeric.Log (Log (Exp))
import System.Random (Random (randoms), getStdGen, newStdGen)
-- | Likelihood-weighted importance sampling.
--
-- Takes the first @n@ weighted samples produced by 'weightedsamples' for @m@,
-- pairs each with the running total of the weights seen so far, and then
-- returns a lazily generated list of draws resampled from those @n@ values:
-- for each uniform random @r@ from 'randoms', the first sample whose running
-- total reaches @r * total@ is chosen, so samples are picked in proportion to
-- their weights.
--
-- If @n <= 0@, or 'weightedsamples' yields no samples, there is nothing to
-- resample from and the result is the empty list (previously every element
-- of the result crashed when forced).
lwis :: Int -> Weighted Sampler a -> IO [a]
lwis n m = do
  xws <- weightedsamples m
  -- Each entry holds a sample and the sum of its weight and all earlier ones.
  let table = take n (accumulate xws 0)
  case table of
    [] -> return []
    _ -> do
      -- The final running total is the overall weight mass of the table.
      let (fallback, total) = last table
      -- Refresh the global generator so repeated calls draw fresh randomness.
      _ <- newStdGen
      rs <- randoms <$> getStdGen
      return (map (select fallback total table) rs)
  where
    -- Running (prefix) sums of the weights, one output entry per input entry.
    -- The accumulator is forced at each step to avoid building a thunk chain.
    accumulate :: Num t => [(b, t)] -> t -> [(b, t)]
    accumulate ((x, w) : rest) acc =
      let acc' = w + acc
       in acc' `seq` ((x, acc') : accumulate rest acc')
    accumulate [] _ = []

    -- Inverse-CDF lookup: first entry whose running total is >= r * total.
    -- Falls back to the last sample if no entry qualifies (e.g. a NaN
    -- threshold when r = 0 and the total weight is infinite).
    select :: b -> Log Double -> [(b, Log Double)] -> Double -> b
    select fallback total table r =
      case filter ((>= Exp (log r) * total) . snd) table of
        (x, _) : _ -> x
        [] -> fallback