-- |
-- Module      : Control.Monad.Bayes.Population
-- Description : Representation of distributions using multiple samples
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'Population' turns a single sample into a collection of weighted samples.
module Control.Monad.Bayes.Population
  ( Population,
    runPopulation,
    explicitPopulation,
    fromWeightedList,
    spawn,
    resampleMultinomial,
    resampleSystematic,
    extractEvidence,
    pushEvidence,
    proper,
    evidence,
    collapse,
    mapPopulation,
    normalize,
    popAvg,
    flatten,
    hoist,
  )
where

import Control.Arrow (second)
import Control.Monad (replicateM)
import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Weighted hiding (flatten, hoist)
import Control.Monad.Trans
import Control.Monad.Trans.List
import qualified Data.List
import qualified Data.Vector as V
import Numeric.Log
import Prelude hiding (all, sum)

-- | A collection of weighted samples, also called particles.
--
-- Internally this is a 'Weighted' computation over 'ListT': running it
-- (see 'runPopulation') yields one @(value, log-domain weight)@ pair per
-- particle. All class instances are inherited from that representation.
newtype Population m a = Population (Weighted (ListT m) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadIO,
      MonadSample,
      MonadCond,
      MonadInfer
    )

-- | Lift an action of the inner monad into 'Population' by lifting it first
-- through 'ListT' and then through 'Weighted'.
instance MonadTrans Population where
  lift action = Population (lift (lift action))

-- | Explicit representation of the weighted sample with weights in the log
-- domain: one @(value, weight)@ pair per particle.
runPopulation :: Functor m => Population m a -> m [(a, Log Double)]
runPopulation (Population weighted) = runListT (runWeighted weighted)

-- | Explicit representation of the weighted sample, with every weight
-- converted from the log domain to a plain 'Double'.
explicitPopulation :: Functor m => Population m a -> m [(a, Double)]
explicitPopulation pop = map (second toLinear) <$> runPopulation pop
  where
    -- 'ln' exposes the stored logarithm, so 'exp' recovers the weight itself.
    toLinear :: Log Double -> Double
    toLinear w = exp (ln w)

-- | Initialize 'Population' with a concrete weighted sample: each
-- @(value, log-domain weight)@ pair in the list returned by the action
-- becomes one particle. This is the inverse of 'runPopulation'.
fromWeightedList :: Monad m => m [(a, Log Double)] -> Population m a
fromWeightedList particles = Population (withWeight (ListT particles))

-- | Increase the sample size by a given factor.
--
-- Produces @n@ particles carrying @()@, each with weight @1/n@. Sequencing
-- it into a population therefore multiplies the number of particles by @n@
-- while keeping the sum of weights unchanged.
-- It is therefore safe to use 'spawn' in arbitrary places in the program
-- without introducing bias.
--
-- For @n <= 0@ no particles are produced, so the population becomes empty.
spawn :: Monad m => Int -> Population m ()
spawn n = fromWeightedList (return copies)
  where
    copies = [((), share) | _ <- [1 .. n]]
    share = 1 / fromIntegral n

-- | Resample a population with the given resampler.
--
-- The resampler receives the normalized weights as plain 'Double's and
-- returns the indices of the selected ancestors. Every selected particle
-- gets weight @z / n@, where @z@ is the total weight and @n@ the number of
-- particles before resampling. The total weight is therefore preserved
-- whenever the resampler returns @n@ indices.
resampleGeneric ::
  MonadSample m =>
  -- | resampler
  (V.Vector Double -> m [Int]) ->
  Population m a ->
  Population m a
resampleGeneric resampler m = fromWeightedList $ do
  pop <- runPopulation m
  let (xs, ps) = unzip pop
      total = sum ps
  -- With no positive total weight (e.g. all weights are zero) there is
  -- nothing to resample from, so the population is returned as is.
  if not (total > 0)
    then return pop
    else do
      let normalized = V.fromList [exp (ln (p / total)) | p <- ps]
          particles = V.fromList xs
          share = total / fromIntegral (length xs)
      ancestors <- resampler normalized
      return [(particles V.! k, share) | k <- ancestors]

-- | Systematic resampling helper.
--
-- Given a draw @u@ (expected in @[0, 1)@) and a vector @ps@ of @n@
-- normalized weights (expected to sum to 1), places @n@ evenly spaced points
-- @(u + i) / n@, for @i = 0 .. n-1@, on the cumulative distribution of @ps@.
-- Each point is mapped to the index whose cumulative interval contains it.
-- The indices are accumulated by prepending, so the result is in
-- non-increasing order.
--
-- The weights are computed in floating point, so their total can fall
-- slightly short of 1, leaving the last points above the accumulated mass.
-- Such points are attributed to the last index. Previously they caused an
-- out-of-bounds read of @ps@.
systematic :: Double -> V.Vector Double -> [Int]
systematic u ps = go 0 (u / fromIntegral n) 0 0 []
  where
    n = V.length ps
    inc = 1 / fromIntegral n
    -- go i v j q acc:
    --   i   = number of points placed so far
    --   v   = position of the current point
    --   j   = number of weights summed into q
    --   q   = cumulative weight of indices 0 .. j-1
    --   acc = selected indices (most recent first)
    go :: Int -> Double -> Int -> Double -> [Int] -> [Int]
    go i v j q acc
      | i == n = acc
      | v < q = go (i + 1) (v + inc) j q (j - 1 : acc)
      -- All weights are consumed but rounding left the mass just below the
      -- current point: the point belongs to the last particle.
      | j >= n = go (i + 1) (v + inc) j q (n - 1 : acc)
      | otherwise = go i v (j + 1) (q + ps V.! j) acc

-- | Resample the population using the underlying monad and a systematic
-- resampling scheme. A single call to 'random' positions all of the evenly
-- spaced selection points (see 'systematic').
-- The total weight is preserved.
resampleSystematic ::
  (MonadSample m) =>
  Population m a ->
  Population m a
resampleSystematic = resampleGeneric drawAncestors
  where
    drawAncestors weights = do
      u <- random
      return (systematic u weights)

-- | Multinomial resampler: draws as many ancestor indices as there are
-- weights, each by a separate call to 'categorical' on the same weights.
multinomial :: MonadSample m => V.Vector Double -> m [Int]
multinomial weights = replicateM count draw
  where
    count = V.length weights
    draw = categorical weights

-- | Resample the population using the underlying monad and a multinomial
-- resampling scheme, in which every ancestor is drawn independently from the
-- normalized weights (see 'multinomial').
-- The total weight is preserved.
resampleMultinomial ::
  (MonadSample m) =>
  Population m a ->
  Population m a
resampleMultinomial pop = resampleGeneric multinomial pop

-- | Separate the sum of weights into the 'Weighted' transformer.
--
-- The total weight is recorded with 'factor' in the outer 'Weighted' layer.
-- Each particle's weight is then divided by that total, so the weights are
-- normalized after this operation. If the total is not positive, every
-- particle instead receives the uniform weight @1/n@.
extractEvidence ::
  Monad m =>
  Population m a ->
  Population (Weighted m) a
extractEvidence m = fromWeightedList $ do
  particles <- lift (runPopulation m)
  let (xs, ps) = unzip particles
      total = sum ps
      rescale
        | total > 0 = (/ total)
        | otherwise = const (1 / fromIntegral (length ps))
  factor total
  return (zip xs (map rescale ps))

-- | Push the evidence estimator as a score to the transformed monad.
--
-- 'extractEvidence' moves the sum of weights into a 'Weighted' layer, and
-- 'applyWeight' then hands that sum to the inner monad.
-- Weights are normalized after this operation.
pushEvidence ::
  MonadCond m =>
  Population m a ->
  Population m a
pushEvidence pop = hoist applyWeight (extractEvidence pop)

-- | A properly weighted single sample, that is one picked at random according
-- to the weights, with the sum of all weights.
--
-- The sum of weights is carried by the resulting 'Weighted' computation (via
-- 'extractEvidence'). The index of the chosen particle is drawn with
-- 'logCategorical' from the normalized log-domain weights.
proper ::
  (MonadSample m) =>
  Population m a ->
  Weighted m a
proper pop = do
  weighted <- runPopulation (extractEvidence pop)
  chosen <- logCategorical (V.fromList (map snd weighted))
  return (fst (weighted !! chosen))

-- | Model evidence estimator, also known as pseudo-marginal likelihood.
--
-- Computed as the weight that 'extractEvidence' moves into the 'Weighted'
-- layer, i.e. the sum of the particle weights. The particles themselves are
-- discarded.
evidence :: (Monad m) => Population m a -> m (Log Double)
evidence pop = extractWeight (runPopulation (extractEvidence pop))

-- | Picks one point from the population and uses model evidence as a 'score'
-- in the transformed monad.
-- This way a single sample can be selected from a population without
-- introducing bias.
--
-- The point is chosen by 'proper', and 'applyWeight' applies the resulting
-- weight in the transformed monad.
collapse ::
  (MonadInfer m) =>
  Population m a ->
  m a
collapse pop = applyWeight (proper pop)

-- | Applies a random transformation to a population.
--
-- The population is run to its explicit list of @(value, weight)@ pairs, the
-- transformation is applied to that list in the underlying monad, and the
-- result becomes the new population.
mapPopulation ::
  (Monad m) =>
  ([(a, Log Double)] -> m [(a, Log Double)]) ->
  Population m a ->
  Population m a
mapPopulation f m = fromWeightedList (f =<< runPopulation m)

-- | Normalizes the weights in the population so that their sum is 1.
-- This transformation introduces bias.
--
-- 'extractEvidence' normalizes the weights and moves their sum into a
-- 'Weighted' layer, which 'prior' then discards.
normalize :: (Monad m) => Population m a -> Population m a
normalize pop = hoist prior (extractEvidence pop)

-- | Population average of a function, computed using unnormalized weights.
--
-- Returns the sum of @f x * w@ over all particles @(x, w)@. The result is
-- not divided by the total weight, so it equals the weighted mean only when
-- the weights sum to 1 (e.g. after 'normalize').
popAvg :: (Monad m) => (a -> Double) -> Population m a -> m Double
popAvg f p = weightedSum <$> explicitPopulation p
  where
    weightedSum particles = Data.List.sum [f x * w | (x, w) <- particles]

-- | Combine a population of populations into a single population.
--
-- Every particle of every inner population becomes a particle of the result.
-- Its weight is its own weight multiplied by the weight of the outer
-- particle that produced that inner population.
flatten :: Monad m => Population (Population m) a -> Population m a
flatten nested = fromWeightedList merged
  where
    merged = combine <$> runPopulation (runPopulation nested)
    combine outer =
      [ (y, p * q)
        | (inner, p) <- outer,
          (y, q) <- inner
      ]

-- | Applies a transformation to the inner monad.
--
-- The population is run to its explicit list of @(value, weight)@ pairs in
-- @m@. The transformation is applied to that single action, and the result
-- is rewrapped as a population over @n@.
hoist ::
  (Monad m, Monad n) =>
  (forall x. m x -> n x) ->
  Population m a ->
  Population n a
hoist f pop = fromWeightedList (f (runPopulation pop))