module Control.Monad.Bayes.Traced.Common
( Trace,
singleton,
output,
scored,
bind,
mhTrans,
mhTrans',
)
where
import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Free as FreeSampler
import Control.Monad.Bayes.Weighted as Weighted
import Control.Monad.Trans.Writer
import Data.Functor.Identity
import Numeric.Log (Log, ln)
import Statistics.Distribution.DiscreteUniform (discreteUniformAB)
-- | A recorded execution of a probabilistic program: the random numbers it
-- consumed, the value it produced, and the density accumulated along the way.
data Trace a = Trace
  { variables :: [Double]
    -- ^ Random values consumed, earlier computations' values first.
  , output :: a
    -- ^ The value the program returned.
  , density :: Log Double
    -- ^ Product of all density factors recorded in this trace.
  }
-- | Transforms only the result; recorded variables and density are untouched.
instance Functor Trace where
  fmap f (Trace vs x w) = Trace vs (f x) w
-- | 'pure' records no variables and has density 1.
-- '<*>' lists the function trace's variables before the argument trace's,
-- multiplies the densities, and applies the function output to the argument
-- output.
instance Applicative Trace where
  pure x = Trace [] x 1

  -- Irrefutable patterns: neither operand is forced until one of the result's
  -- fields is demanded, matching a field-by-field record construction.
  ~(Trace vf f wf) <*> ~(Trace vx x wx) = Trace (vf ++ vx) (f x) (wf * wx)
-- | Feeds the first trace's output to the continuation, then prepends the
-- first trace's variables and multiplies in its density.
instance Monad Trace where
  -- The left trace is matched lazily; only the continuation's result is
  -- forced, exactly as a record update on it would.
  ~(Trace vs x w) >>= k = case k x of
    Trace vs' y w' -> Trace (vs ++ vs') y (w * w')
-- | A trace consisting of one random value, which is also its output;
-- density 1.
singleton :: Double -> Trace Double
singleton u = Trace [u] u 1
-- | A trace with no random values and unit output whose density is the given
-- weight.
scored :: Log Double -> Trace ()
scored = Trace [] ()
-- | Sequences two trace-producing computations in an outer monad @m@.
--
-- The continuation receives the first trace's 'output'. The combined trace
-- has the second trace's output, the first trace's variables followed by the
-- second's, and the product of both densities.
bind :: Monad m => m (Trace a) -> (a -> m (Trace b)) -> m (Trace b)
bind dx f =
  dx >>= \t1 ->
    f (output t1) >>= \t2 ->
      -- This is precisely the 'Trace' monad's own sequencing of t1 then t2.
      return (t1 >>= const t2)
-- | One Metropolis-Hastings step of single-site trace MCMC.
--
-- Chooses one recorded random variable of the trace uniformly at random,
-- replaces it with a fresh 'random' draw, and re-runs the model @m@ on the
-- modified randomness ('withPartialRandomness'). The run yields the new
-- output, density @q@ and variables @vs@.
--
-- The proposal is accepted with probability
-- @min 1 (q * n / (p * length vs))@, where @p@ is the current density and
-- @n@ the current number of variables. The @n / length vs@ factor accounts
-- for the uniform choice of site in each direction.
--
-- A trace that recorded no random variables has nothing to perturb, so it
-- is returned unchanged without re-running the model.
mhTrans :: MonadSample m => Weighted (FreeSampler m) a -> Trace a -> m (Trace a)
mhTrans _ t@Trace {variables = []} = return t
mhTrans m t@Trace {variables = us, density = p} = do
  let n = length us
  -- Site to perturb, in [0, n - 1]; n > 0 is guaranteed by the clause above.
  i <- discrete $ discreteUniformAB 0 (n - 1)
  u' <- random
  -- Replace the i-th variable with u'. Total, unlike a partial pattern on
  -- the result of 'splitAt'.
  let us' = [if j == i then u' else u | (j, u) <- zip [0 ..] us]
  ((b, q), vs) <-
    runWriterT $ runWeighted $ Weighted.hoist (WriterT . withPartialRandomness us') m
  -- 'exp . ln' converts the log-domain acceptance probability to a Double.
  let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs)))
  accept <- bernoulli ratio
  return $ if accept then Trace vs b q else t
-- | 'mhTrans' for a model whose base monad is 'Identity'.
--
-- The model is first lifted into @FreeSampler m@ by embedding each pure
-- 'Identity' step with 'return'. The resulting transition then runs in the
-- caller's sampling monad @m@.
mhTrans' :: MonadSample m => Weighted (FreeSampler Identity) a -> Trace a -> m (Trace a)
mhTrans' = mhTrans . Weighted.hoist (FreeSampler.hoist (return . runIdentity))