module Control.Monad.Bayes.Inference.Lazy.WIS where

import Control.Monad.Bayes.Sampler.Lazy (SamplerT, weightedSamples)
import Control.Monad.Bayes.Weighted (WeightedT)
import Numeric.Log (Log (Exp))
import System.Random (Random (randoms), getStdGen, newStdGen)

-- Weighted importance sampling (WIS).

-- | Likelihood weighted importance sampling.
--
-- First draws @n@ weighted samples from the model, then returns an infinite,
-- lazily generated stream of results resampled from those @n@ samples. Each
-- result is chosen with probability proportional to its weight, i.e. the @n@
-- samples are treated as an empirical distribution.
--
-- If no samples are obtained (for example when @n <= 0@) the empirical
-- distribution has no support and the result is the empty list.
lwis :: Int -> WeightedT (SamplerT IO) a -> IO [a]
lwis n m = do
  xws <- weightedSamples m
  -- Pair each of the first n samples with the running total of the weights
  -- up to and including it (a cumulative distribution over the samples).
  case accumulate (take n xws) 0 of
    [] -> return []
    entries -> do
      -- The final running total normalises the empirical distribution.
      let (fallback, total) = last entries
          -- Inverse-CDF lookup for a uniform r in [0, 1): take the first
          -- sample whose cumulative weight reaches r * total. Since r < 1 the
          -- last entry always qualifies; 'fallback' only guards against
          -- incomparable (NaN) weights so that this never crashes.
          pick r = case filter ((>= Exp (log r) * total) . snd) entries of
            (x, _) : _ -> x
            [] -> fallback
      -- NOTE(review): the stream below is generated from the generator that
      -- stays stored in the global StdGen after the split, so a later global
      -- draw (e.g. randomIO) may repeat its first values. Consider using the
      -- generator returned by newStdGen directly.
      _ <- newStdGen
      rs <- randoms <$> getStdGen
      return (fmap pick rs)
  where
    -- Replace each weight with the running sum of weights so far, emitting
    -- exactly one entry per input sample (so 'take n' keeps n samples).
    accumulate :: (Num t) => [(b, t)] -> t -> [(b, t)]
    accumulate ((x, w) : rest) acc =
      let acc' = w + acc
       in (x, acc') : accumulate rest acc'
    accumulate [] _ = []