{-# LANGUAGE NamedFieldPuns #-}
module Data.Automaton.Bayes where
import Control.Arrow
import Control.Monad.Trans.Reader (ReaderT (..))
import Numeric.Log hiding (sum)
import Control.Monad.Bayes.Population (PopulationT (..), fromWeightedList, runPopulationT)
import Control.Monad.Morph (hoist)
import Data.Automaton (Automaton (..), handleAutomaton)
import Data.Stream (StreamT (..))
import Data.Stream.Result (Result (..))
-- | Run the 'PopulationT' layer of an automaton, exposing the weighted particles.
--
-- The resulting automaton keeps a list of particle states as its own state.
-- It starts from @nParticles@ copies of the original initial state, each
-- carrying the weight @1 / nParticles@. On every step, each particle is
-- advanced with the original step function, the given resampler is applied
-- to the resulting population, the weights are passed through 'normalize',
-- and the outputs are returned paired with their weights.
runPopulationS ::
  forall m a b.
  (Monad m) =>
  -- | Number of particles the initial state is replicated into
  Int ->
  -- | Resampler, applied to the population after every step
  (forall x. PopulationT m x -> PopulationT m x) ->
  -- | Automaton with a population effect
  Automaton (PopulationT m) a b ->
  -- | Automaton emitting every particle's output together with its weight
  Automaton m a [(b, Log Double)]
runPopulationS nParticles resampler =
  handleAutomaton
    ( runPopulationStream
        -- The resampler works on 'PopulationT m', but the stream lives in
        -- 'PopulationT (ReaderT a m)': commute the layers, resample, commute back.
        (commuteReaderPopulation . hoist resampler . commuteReaderPopulationBack)
        . hoist commuteReaderPopulation
    )
  where
    -- Move the 'ReaderT' layer underneath 'PopulationT' by running both
    -- layers and rebuilding the population from the weighted list.
    commuteReaderPopulation :: forall m r a. (Monad m) => ReaderT r (PopulationT m) a -> PopulationT (ReaderT r m) a
    commuteReaderPopulation = fromWeightedList . ReaderT . fmap runPopulationT . runReaderT

    -- Inverse direction of 'commuteReaderPopulation'.
    commuteReaderPopulationBack :: forall m r a. (Monad m) => PopulationT (ReaderT r m) a -> ReaderT r (PopulationT m) a
    commuteReaderPopulationBack = ReaderT . fmap fromWeightedList . runReaderT . runPopulationT

    -- Turn a stream with a population effect into a stream whose state is
    -- the explicit list of weighted particle states.
    runPopulationStream ::
      forall m b.
      (Monad m) =>
      (forall x. PopulationT m x -> PopulationT m x) ->
      StreamT (PopulationT m) b ->
      StreamT m [(b, Log Double)]
    runPopulationStream resample StreamT {step = stepParticle, state = initialState} =
      StreamT
        { state = replicate nParticles (initialState, initialWeight)
        , step = \particles -> do
            -- Advance every particle, then resample, then normalize.
            stepped <- runPopulationT . normalize . resample $ do
              particle <- fromWeightedList $ pure particles
              stepParticle particle
            -- The same weights are attached to both the new states and the outputs.
            pure $! Result (first resultState <$> stepped) (first output <$> stepped)
        }
      where
        -- Every particle starts with the same weight.
        initialWeight :: Log Double
        initialWeight = 1 / fromIntegral nParticles
-- | Divide the weight of every particle by the sum of all weights.
--
-- The total is computed once per population and shared across all particles.
--
-- NOTE(review): there is no guard for a total weight of zero (for example an
-- empty population or all-zero weights); the division is performed regardless.
normalize :: (Monad m) => PopulationT m a -> PopulationT m a
normalize = fromWeightedList . fmap rescale . runPopulationT
  where
    -- Rescale one list of weighted particles by their total weight.
    rescale particles =
      let total = sum (snd <$> particles)
       in second (/ total) <$> particles