module Control.Monad.Bayes.Traced.Basic
( Traced,
hoistT,
marginal,
mhStep,
mh,
)
where
import Control.Applicative (liftA2)
import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Free (FreeSampler)
import Control.Monad.Bayes.Traced.Common
import Control.Monad.Bayes.Weighted (Weighted)
import Data.Functor.Identity (Identity)
-- | A probabilistic program paired with a distribution over its recorded
-- execution traces.
--
-- The 'model' component is handed to the Metropolis-Hastings transition
-- together with a trace (see 'mhStep' / 'mh'); the 'traceDist' component
-- produces traces from which the program's output can be read back.
data Traced m a = Traced
  { model :: Weighted (FreeSampler Identity) a
    -- ^ The program interpreted in @Weighted (FreeSampler Identity)@;
    -- presumably re-run against proposed traces by the MH kernel — confirm
    -- in "Control.Monad.Bayes.Traced.Common".
  , traceDist :: m (Trace a)
    -- ^ Distribution over recorded traces, each carrying the output value.
  }
-- | Mapping over a traced program maps the output of the model and, inside
-- the underlying monad, the output stored in each recorded trace.
instance Monad m => Functor (Traced m) where
  fmap g (Traced mdl dist) = Traced (g <$> mdl) (fmap g <$> dist)
-- | Applicative structure is inherited component-wise: the models are
-- combined with the 'Weighted' applicative, and the trace distributions are
-- combined by applying 'Trace''s own '<*>' inside @m@.
instance Monad m => Applicative (Traced m) where
  pure a = Traced (pure a) (pure trivialTrace)
    where
      -- A trace that only carries the value, lifted into @m@ above.
      trivialTrace = pure a

  (Traced mdlF distF) <*> (Traced mdlX distX) =
    Traced (mdlF <*> mdlX) (liftA2 (<*>) distF distX)
-- | Sequencing binds the two components independently: the model with the
-- 'Weighted' bind, projecting 'model' out of the continuation's result; the
-- trace distribution with 'bind' from "Control.Monad.Bayes.Traced.Common",
-- projecting 'traceDist' out of the continuation's result.
-- NOTE(review): how 'bind' merges the traces of the two stages is defined in
-- Traced.Common, not here.
instance Monad m => Monad (Traced m) where
  (Traced mdl dist) >>= k =
    Traced
      (mdl >>= \x -> model (k x))
      (dist `bind` (\x -> traceDist (k x)))
-- | A primitive draw: the model draws with 'random', and the trace side draws
-- in @m@ and wraps the result with 'singleton' (presumably a trace holding
-- just that one choice — see Traced.Common).
instance MonadSample m => MonadSample (Traced m) where
  random = Traced random (singleton <$> random)
-- | Conditioning applies the same weight to the model and to the underlying
-- monad, and returns a trace built by 'scored' from that weight.
instance MonadCond m => MonadCond (Traced m) where
  score weight = Traced (score weight) $ do
    score weight
    pure (scored weight)
-- | 'Traced' supports the full inference interface whenever @m@ does.
-- The instance body is empty: no methods are overridden here.
instance (MonadInfer m) => MonadInfer (Traced m)
-- | Apply a polymorphic transformation of the underlying monad to the trace
-- distribution, leaving the model untouched.
hoistT :: (forall x. m x -> m x) -> Traced m a -> Traced m a
hoistT nat traced = case traced of
  Traced mdl dist -> Traced mdl (nat dist)
-- | Forget the tracing machinery and keep only the distribution over the
-- program's output (read from each trace with 'output').
marginal :: Monad m => Traced m a -> m a
marginal (Traced _ dist) = output <$> dist
-- | One step of trace Metropolis-Hastings: every trace produced by the
-- current distribution is pushed through the transition 'mhTrans'' for this
-- model. The model itself is unchanged.
mhStep :: MonadSample m => Traced m a -> Traced m a
mhStep (Traced mdl dist) = Traced mdl (mhTrans' mdl =<< dist)
-- | Run a trace Metropolis-Hastings chain with @n@ transitions.
--
-- One initial trace is drawn from the trace distribution, then 'mhTrans''
-- is applied @n@ times, each step starting from the previous trace. The
-- result holds the outputs of all @n + 1@ states, newest first (the initial
-- sample is last).
--
-- A non-positive @n@ yields only the initial sample. Previously a negative
-- @n@ made the helper recurse forever, because its only base case was @0@.
mh :: MonadSample m => Int -> Traced m a -> m [a]
mh n (Traced mdl dist) = map output . snd <$> chain n
  where
    -- chain k = (newest trace, every trace newest-first) after k steps.
    -- Threading the newest trace explicitly replaces the partial lazy
    -- pattern ~(x : xs) that the list-only version needed.
    chain k
      | k <= 0 = do
          start <- dist
          pure (start, [start])
      | otherwise = do
          (newest, history) <- chain (k - 1)
          next <- mhTrans' mdl newest
          pure (next, next : history)