module Control.Monad.Bayes.Inference.Lazy.WIS where

import Control.Monad (guard)
import Control.Monad.Bayes.Sampler.Lazy (SamplerT, weightedSamples)
import Control.Monad.Bayes.Weighted (WeightedT)
import Data.Maybe (mapMaybe)
import Numeric.Log (Log (Exp))
import System.Random (Random (randoms), getStdGen, newStdGen)

-- Weighted Importance Sampling

-- | Likelihood-weighted importance sampling.
--
-- First draws @n@ weighted samples from the model. It then returns an
-- infinite lazy stream of results resampled from those @n@ samples, regarded
-- as an empirical distribution. Assuming the generated doubles are uniform on
-- [0, 1), each sample is picked with probability proportional to its weight.
--
-- Returns an empty list when @n <= 0@ (or when the model yields no samples).
lwis :: Int -> WeightedT (SamplerT IO) a -> IO [a]
lwis n m = do
  xws <- weightedSamples m
  -- Pair each of the first n samples with its running (cumulative) weight.
  let xws' = take n $ accumulate xws 0
  case xws' of
    -- Without this check, 'last' below would be partial. The result stream
    -- would also never yield an element, because every draw would map to [].
    [] -> return []
    _ -> do
      -- Cumulative weights are running sums, so the final one is the total
      -- weight of the retained samples.
      let max' = snd $ last xws'
      _ <- newStdGen
      rs <- randoms <$> getStdGen
      -- For each draw r, select the first sample whose cumulative weight
      -- reaches r * total (inverse-CDF selection over the retained samples).
      let pick r =
            take 1 $
              mapMaybe (\(a, p) -> a <$ guard (p >= Exp (log r) * max')) xws'
      return $ concatMap pick rs
  where
    -- Running sum of weights: (x_k, w_k + ... + w_1), exactly one entry per
    -- input sample, so 'take n' keeps n distinct samples.
    accumulate :: (Num t) => [(b, t)] -> t -> [(b, t)]
    accumulate [] _ = []
    accumulate ((x, w) : rest) acc =
      let acc' = w + acc
       in (x, acc') : accumulate rest acc'