{-# LANGUAGE RankNTypes #-}
module Control.Monad.Bayes.Traced.Dynamic
( Traced,
hoist,
marginal,
freeze,
mhStep,
mh,
)
where
import Control.Monad (join)
import Control.Monad.Bayes.Class
( MonadDistribution (random),
MonadFactor (..),
MonadMeasure,
)
import Control.Monad.Bayes.Density.Free (Density)
import Control.Monad.Bayes.Traced.Common
( Trace (..),
bind,
mhTransFree,
scored,
singleton,
)
import Control.Monad.Bayes.Weighted (Weighted)
import Control.Monad.Trans (MonadTrans (..))
import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList)
-- | A traced probabilistic computation over base monad @m@.
--
-- Running it (via 'runTraced') yields a pair:
--
-- * the model, a @'Weighted' ('Density' m) a@ computation. 'mhStep' and 'mh'
--   hand it to 'mhTransFree' to produce new traces, and 'freeze' replaces it
--   with a constant.
-- * the 'Trace' of the current execution, whose 'output' is the value
--   returned by 'marginal'.
newtype Traced m a = Traced {runTraced :: m (Weighted (Density m) a, Trace a)}
-- | Flatten a base-monad action that produces a model into the model itself.
--
-- The outer @m@ action is lifted twice (into 'Density' and then into
-- 'Weighted'). The resulting nested model is collapsed with 'join'.
pushM :: Monad m => m (Weighted (Density m) a) -> Weighted (Density m) a
pushM = join . lift . lift
instance Monad m => Functor (Traced m) where
  -- Map the function over both the model and the recorded trace so the two
  -- components stay in sync.
  fmap f (Traced c) = Traced $ fmap (\(m, t) -> (fmap f m, fmap f t)) c
instance Monad m => Applicative (Traced m) where
  -- The value is lifted into both the model and the trace with their own
  -- 'pure'. No base-monad effect is performed.
  pure x = Traced $ pure (pure x, pure x)

  -- Run the function computation, then the argument computation, and
  -- combine the two models and the two traces component-wise.
  (Traced cf) <*> (Traced cx) = Traced $ do
    (mf, tf) <- cf
    (mx, tx) <- cx
    return (mf <*> mx, tf <*> tx)
instance Monad m => Monad (Traced m) where
  (Traced cx) >>= f = Traced $ do
    (mx, tx) <- cx
    -- The continuation's model only becomes available after running the
    -- base computation in @m@, so it is embedded into the model via 'pushM'.
    let m = mx >>= pushM . fmap fst . runTraced . f
    -- Extend the current trace with the trace produced by the continuation.
    t <- return tx `bind` (fmap snd . runTraced . f)
    return (m, t)
instance MonadTrans Traced where
  -- The base action runs once now; its result becomes the traced value.
  -- The same action is also lifted into the model, so re-running the model
  -- re-runs it.
  lift m = Traced $ fmap (\x -> (lift (lift m), pure x)) m
instance MonadDistribution m => MonadDistribution (Traced m) where
  -- Draw a sample in the base monad and record it as a single-choice trace.
  -- The model is a fresh 'random' draw of its own.
  random = Traced $ fmap (\u -> (random, singleton u)) random
instance MonadFactor m => MonadFactor (Traced m) where
  -- Apply the factor in the base monad and record it in the trace with
  -- 'scored'. The model applies the same factor.
  -- The prefix pair constructor is used (rather than the tuple section
  -- @(score w,)@) so that TupleSections does not need to be enabled; this
  -- module only declares RankNTypes.
  score w = Traced $ (,) (score w) <$> (score w >> pure (scored w))
-- No methods are defined here. Presumably 'MonadMeasure' only combines
-- 'MonadDistribution' and 'MonadFactor', whose Traced instances are
-- defined in this module.
instance (MonadMeasure m) => MonadMeasure (Traced m)
-- | Apply a transformation of the base monad to the outer computation of a
-- traced value.
--
-- Only the outer @m@ action is transformed. The model it produces, which
-- still runs in @'Density' m@, is left untouched.
hoist :: (forall x. m x -> m x) -> Traced m a -> Traced m a
hoist f (Traced c) = Traced (f c)
-- | Run the computation and return only the 'output' of its trace,
-- discarding both the trace and the model.
marginal :: Monad m => Traced m a -> m a
marginal (Traced c) = fmap (output . snd) c
-- | Run the computation once and fix its result.
--
-- Both the model and the trace are replaced by the constant 'output' of the
-- current trace. Anything that later re-runs the model sees that fixed value.
freeze :: Monad m => Traced m a -> Traced m a
freeze (Traced c) = Traced $ do
  (_, t) <- c
  let x = output t
  -- 'pure' is used for both components, consistent with the Applicative
  -- instance (the original used 'return' for one and 'pure' for the other).
  return (pure x, pure x)
-- | Perform a single MH step.
--
-- Run the computation, then pass its model and current trace to
-- 'mhTransFree' to obtain the next trace. The model itself is kept as is.
mhStep :: MonadDistribution m => Traced m a -> Traced m a
mhStep (Traced c) = Traced $ do
  (m, t) <- c
  t' <- mhTransFree m t
  return (m, t')
-- | Run an MH chain of @n@ steps and return the output of every trace visited.
--
-- The chain starts from the trace obtained by running the computation once.
-- Each step applies 'mhTransFree' to the most recent trace. The result holds
-- @n + 1@ outputs (only the initial one when @n <= 0@), ordered from the
-- newest sample back to the initial one.
mh :: MonadDistribution m => Int -> Traced m a -> m [a]
mh n (Traced c) = do
  (m, t) <- c
  -- Accumulator-passing loop. The chain is built newest-first and the
  -- recursive call is the last action of each step. This avoids the nested
  -- @f (k - 1) >>= ...@ continuations of a non-tail-recursive formulation,
  -- while performing the same 'mhTransFree' calls in the same order.
  let go k chain@(x :| older)
        | k <= 0 = return chain
        | otherwise = do
            y <- mhTransFree m x
            go (k - 1) (y :| x : older)
  fmap (map output . NE.toList) (go n (t :| []))