{-# LANGUAGE NamedFieldPuns #-}

module Data.Automaton.Bayes where

-- base
import Control.Arrow

-- transformers
import Control.Monad.Trans.Reader (ReaderT (..))

-- log-domain
import Numeric.Log hiding (sum)

-- monad-bayes
import Control.Monad.Bayes.Population (PopulationT (..), fromWeightedList, runPopulationT)

-- mmorph
import Control.Monad.Morph (hoist)

-- automaton
import Data.Automaton (Automaton (..), handleAutomaton)
import Data.Stream (StreamT (..))
import Data.Stream.Result (Result (..))

-- | Run the Sequential Monte Carlo algorithm continuously on an 'Automaton'.
--
-- The resulting automaton keeps a list of weighted internal states (the
-- particles). It starts with @nParticles@ copies of the initial state, each
-- carrying the weight @1 / nParticles@. On every tick, each particle is
-- advanced by one step of the original automaton inside 'PopulationT', the
-- given resampler is applied to the resulting population, the weights are
-- normalized, and the outputs of all particles are emitted together with
-- their weights.
runPopulationS ::
  forall m a b.
  (Monad m) =>
  -- | Number of particles
  Int ->
  -- | Resampler
  (forall x. PopulationT m x -> PopulationT m x) ->
  Automaton (PopulationT m) a b ->
  -- FIXME Why not Automaton m a (PopulationT b)
  Automaton m a [(b, Log Double)]
runPopulationS nParticles resampler =
  handleAutomaton
    ( runPopulationStream
        (commuteReaderPopulation . hoist resampler . commuteReaderPopulationBack)
        . hoist commuteReaderPopulation
    )
  where
    -- Move the reader layer from outside of the population to inside of it.
    commuteReaderPopulation :: forall m r a. (Monad m) => ReaderT r (PopulationT m) a -> PopulationT (ReaderT r m) a
    commuteReaderPopulation = fromWeightedList . ReaderT . fmap runPopulationT . runReaderT

    -- Opposite direction of 'commuteReaderPopulation': move the reader layer
    -- from inside of the population to outside of it.
    commuteReaderPopulationBack :: forall m r a. (Monad m) => PopulationT (ReaderT r m) a -> ReaderT r (PopulationT m) a
    commuteReaderPopulationBack = ReaderT . fmap fromWeightedList . runReaderT . runPopulationT

    -- Turn a stream in 'PopulationT' into a stream of weighted output lists.
    -- The new internal state is the list of weighted particle states.
    runPopulationStream ::
      forall m b.
      (Monad m) =>
      (forall x. PopulationT m x -> PopulationT m x) ->
      StreamT (PopulationT m) b ->
      StreamT m [(b, Log Double)]
    runPopulationStream resample StreamT {step, state} =
      StreamT
        { -- Every particle starts from the same initial state with equal weight.
          state = replicate nParticles (state, 1 / fromIntegral nParticles)
        , step = \states -> do
            -- Step every particle, resample, then normalize the weights.
            resultsAndProbabilities <-
              runPopulationT $
                normalize $
                  resample $
                    fromWeightedList (pure states) >>= step
            -- Split each weighted result into the next particle state and the
            -- emitted output; both keep the particle's weight.
            return $!
              Result
                (first resultState <$> resultsAndProbabilities)
                (first output <$> resultsAndProbabilities)
        }

-- FIXME see PR re-adding this to monad-bayes

-- | Rescale the weights of a population by dividing every weight by the sum
-- of all weights. The particles themselves and their order are unchanged.
--
-- There is no special handling for an empty population or for a total weight
-- of zero.
normalize :: (Monad m) => PopulationT m a -> PopulationT m a
normalize = fromWeightedList . fmap rescale . runPopulationT
  where
    -- Divide each particle's weight by the total weight of the population.
    rescale particles = second (/ total) <$> particles
      where
        total = sum (snd <$> particles)