module Data.MonadicStreamFunction.Bayes where

-- base
import Control.Arrow
import Data.Functor (($>))
import Data.Tuple (swap)

-- transformers

-- log-domain
import Numeric.Log hiding (sum)

-- monad-bayes
import Control.Monad.Bayes.Population

-- dunai
import Data.MonadicStreamFunction
import Data.MonadicStreamFunction.InternalCore (MSF (..))

-- | Run the Sequential Monte Carlo algorithm continuously on an 'MSF'.
--
-- The initial population is created with @'spawn' nParticles@, and every
-- particle is replaced by the given stream function via '($>)'. Stepping,
-- resampling and weighting are then delegated to 'runPopulationsS'.
runPopulationS ::
  forall m a b.
  Monad m =>
  -- | Number of particles
  Int ->
  -- | Resampler
  (forall x. Population m x -> Population m x) ->
  MSF (Population m) a b ->
  -- FIXME Why not MSF m a (Population b)
  MSF m a [(b, Log Double)]
runPopulationS nParticles resampler msf =
  runPopulationsS resampler (spawn nParticles $> msf)

-- | Run the Sequential Monte Carlo algorithm continuously on a 'Population' of 'MSF's.
--
-- On every tick:
--
-- 1. each particle's 'MSF' is stepped with the current input;
-- 2. the resampler is applied to the resulting population;
-- 3. the weights are rescaled with 'normalize'.
--
-- The tick's output is the list of weighted particle outputs. The
-- continuations, each keeping its particle's weight, form the population
-- for the next tick.
runPopulationsS ::
  Monad m =>
  -- | Resampler
  (forall x. Population m x -> Population m x) ->
  Population m (MSF (Population m) a b) ->
  MSF m a [(b, Log Double)]
runPopulationsS resampler = go
  where
    go msfs = MSF $ \a -> do
      -- TODO This is quite different than the dunai version now. Maybe it's right nevertheless.
      -- FIXME This normalizes the weights, which reportedly introduces bias; investigate.
      let stepped = msfs >>= \msf -> unMSF msf a
      particles <- runPopulation $ normalize $ resampler stepped
      -- Split every weighted (output, continuation) pair into a weighted
      -- output and a weighted continuation; both keep the same weight.
      let outputs = [(b, w) | ((b, _), w) <- particles]
          continuations = [(next, w) | ((_, next), w) <- particles]
      return (outputs, go $ fromWeightedList $ return continuations)

-- | Rescale the weights of a 'Population' so that they sum to one.
--
-- If the total weight is not positive (e.g. every particle has weight zero),
-- dividing by it would turn every weight into NaN in the 'Log' domain
-- (@log 0 - log 0@). In that case uniform weights are assigned instead.
-- An empty population stays empty.
--
-- NOTE(review): the uniform fallback is intended to mirror the zero-evidence
-- handling of monad-bayes' @extractEvidence@ — confirm when re-adding upstream.
--
-- FIXME see PR re-adding this to monad-bayes
normalize :: Monad m => Population m a -> Population m a
normalize = fromWeightedList . fmap rescale . runPopulation
  where
    rescale particles =
      let total = sum (snd <$> particles)
          reweight
            | total > 0 = (/ total)
            | otherwise = const (1 / fromIntegral (length particles))
       in second reweight <$> particles