module Control.Monad.Bayes.Traced.Common
( Trace (..),
singleton,
scored,
bind,
mhTrans,
mhTransWithBool,
mhTransFree,
mhTrans',
burnIn,
MHResult (..),
)
where
import Control.Monad.Bayes.Class
( MonadDistribution (bernoulli, random),
discrete,
)
import Control.Monad.Bayes.Density.Free qualified as Free
import Control.Monad.Bayes.Density.State qualified as State
import Control.Monad.Bayes.Weighted as Weighted
( Weighted,
hoist,
weighted,
)
import Control.Monad.Writer (WriterT (WriterT, runWriterT))
import Data.Functor.Identity (Identity (runIdentity))
import Numeric.Log (Log, ln)
import Statistics.Distribution.DiscreteUniform (discreteUniformAB)
-- | The outcome of one Metropolis-Hastings transition, as returned by
-- 'mhTransWithBool'.
data MHResult a = MHResult
  { -- | Whether the transition adopted a newly proposed trace.
    success :: Bool,
    -- | The trace after the transition: the proposed trace when it was
    -- accepted, otherwise the trace the transition started from.
    trace :: Trace a
  }
-- | A record of one execution of a probabilistic program: the random draws it
-- made, the value it returned, and its accumulated density.
data Trace a = Trace
  { -- | The random draws recorded during the execution, in the order they
    -- were made (presumably uniform draws from 'random' — confirm against
    -- the samplers that build traces).
    variables :: [Double],
    -- | The value returned by the execution.
    output :: a,
    -- | The density of the execution. Sequencing traces multiplies their
    -- densities; 'scored' introduces a factor.
    probDensity :: Log Double
  }
-- | Transforms the returned value only; the recorded draws and the density
-- are carried over untouched.
instance Functor Trace where
  fmap f (Trace draws x density) = Trace draws (f x) density
-- | 'pure' is a trace with no draws and density 1. '<*>' concatenates the
-- draws (function trace first), applies the function output to the argument
-- output, and multiplies the two densities.
instance Applicative Trace where
  pure x = Trace [] x 1
  tf <*> tx =
    Trace
      (variables tf <> variables tx)
      (output tf (output tx))
      (probDensity tf * probDensity tx)
-- | Feeds the first trace's output to the continuation. The resulting trace
-- keeps the continuation's output, prepends the first trace's draws to the
-- continuation's draws, and multiplies the two densities.
instance Monad Trace where
  t >>= f =
    case f (output t) of
      Trace laterDraws y laterDensity ->
        Trace (variables t ++ laterDraws) y (probDensity t * laterDensity)
-- | A trace that records exactly one draw, returns that same draw, and has
-- density 1.
singleton :: Double -> Trace Double
singleton u = Trace [u] u 1
-- | A trace with no draws and a unit result whose density is the given
-- weight.
scored :: Log Double -> Trace ()
scored = Trace [] ()
-- | Sequences two trace-producing computations in an outer monad @m@.
--
-- The continuation is given the first trace's output. The combined trace
-- returns the second trace's output, lists the first trace's draws before
-- the second's, and multiplies the two densities. That is exactly '(>>)' on
-- 'Trace' (its '>>=' ignores nothing but the discarded output), so the
-- combination is delegated to it.
bind :: Monad m => m (Trace a) -> (a -> m (Trace b)) -> m (Trace b)
bind dx f = do
  earlier <- dx
  later <- f (output earlier)
  pure (earlier >> later)
-- | A single Metropolis-corrected transition of single-site Trace MCMC that
-- re-runs the model through 'State.density'.
--
-- Step by step:
--
-- * One of the @n@ recorded draws of the current trace is chosen uniformly at
--   random and replaced by a fresh 'random' draw.
-- * The model is re-run on the modified draw list, giving an output @b@, a
--   weight @q@ and the list of @k@ draws it consumed.
-- * The proposal is accepted with probability
--
-- > min 1 (q * n / (p * k))
--
--   where @p@ is the density of the current trace.
-- * On rejection the current trace is returned unchanged.
--
-- A trace with no recorded draws has no site to perturb: the index
-- distribution over @[0 .. n - 1]@ would be empty. Such a trace is therefore
-- returned as-is rather than crashing or looping while picking an index.
mhTrans ::
  MonadDistribution m =>
  Weighted (State.Density m) a ->
  Trace a ->
  m (Trace a)
mhTrans m t@Trace {variables = us, probDensity = p}
  | null us = pure t
  | otherwise = do
      us' <- perturb
      ((b, q), vs) <- State.density (weighted m) us'
      -- The n / length vs factor presumably corrects for choosing the site
      -- uniformly among n draws here versus length vs draws in the proposal.
      let ratio =
            (exp . ln) $
              min 1 (q * fromIntegral n / (p * fromIntegral (length vs)))
      accept <- bernoulli ratio
      pure $ if accept then Trace vs b q else t
  where
    n = length us
    -- Replace the draw at a uniformly chosen index with a fresh 'random' one.
    perturb = do
      i <- discrete $ discreteUniformAB 0 (n - 1)
      u' <- random
      case splitAt i us of
        -- 0 <= i < n here, so the suffix always has a head.
        (xs, _ : ys) -> pure $ xs ++ (u' : ys)
        _ -> error "impossible"
-- | Runs one 'mhTransWithBool' step, drops the acceptance flag, and returns
-- only the resulting trace.
mhTransFree ::
  MonadDistribution m =>
  Weighted (Free.Density m) a ->
  Trace a ->
  m (Trace a)
mhTransFree m = fmap trace . mhTransWithBool m
-- | A single Metropolis-corrected transition of single-site Trace MCMC that
-- re-runs the model through 'Free.density', and reports whether the proposal
-- was accepted.
--
-- Step by step:
--
-- * One of the @n@ recorded draws of the current trace is chosen uniformly at
--   random and replaced by a fresh 'random' draw.
-- * The model is re-run on the modified draw list, giving an output @b@, a
--   weight @q@ and the list of @k@ draws it consumed.
-- * The proposal is accepted with probability
--
-- > min 1 (q * n / (p * k))
--
--   where @p@ is the density of the current trace.
-- * The result is @MHResult True newTrace@ on acceptance and
--   @MHResult False t@ on rejection.
--
-- A trace with no recorded draws has no site to perturb: the index
-- distribution over @[0 .. n - 1]@ would be empty. Such a trace yields
-- @MHResult False t@, because no proposal was made.
mhTransWithBool ::
  MonadDistribution m =>
  Weighted (Free.Density m) a ->
  Trace a ->
  m (MHResult a)
mhTransWithBool m t@Trace {variables = us, probDensity = p}
  | null us = pure (MHResult False t)
  | otherwise = do
      us' <- perturb
      -- Interpret the free density against the proposed draws, collecting
      -- the draws actually consumed via the WriterT layer.
      ((b, q), vs) <-
        runWriterT $ weighted $ Weighted.hoist (WriterT . Free.density us') m
      -- The n / length vs factor presumably corrects for choosing the site
      -- uniformly among n draws here versus length vs draws in the proposal.
      let ratio =
            (exp . ln) $
              min 1 (q * fromIntegral n / (p * fromIntegral (length vs)))
      accept <- bernoulli ratio
      pure $
        if accept
          then MHResult True (Trace vs b q)
          else MHResult False t
  where
    n = length us
    -- Replace the draw at a uniformly chosen index with a fresh 'random' one.
    perturb = do
      i <- discrete $ discreteUniformAB 0 (n - 1)
      u' <- random
      case splitAt i us of
        -- 0 <= i < n here, so the suffix always has a head.
        (xs, _ : ys) -> pure $ xs ++ (u' : ys)
        _ -> error "impossible"
-- | Runs 'mhTransFree' on a model whose 'Free.Density' is over 'Identity'.
-- It first lifts that pure base monad into the sampling monad @m@.
mhTrans' ::
  MonadDistribution m =>
  Weighted (Free.Density Identity) a ->
  Trace a ->
  m (Trace a)
mhTrans' = mhTransFree . Weighted.hoist (Free.hoist (pure . runIdentity))
-- | Discards the last @n@ elements of the list produced inside @m@.
--
-- Because the elements are dropped from the end, this presumably assumes
-- chains are stored newest-first, so the end holds the earliest (burn-in)
-- samples. Confirm this against the callers.
--
-- * If @n@ exceeds the list length, the result is empty.
-- * A non-positive @n@ leaves the list unchanged.
burnIn :: Functor m => Int -> m [a] -> m [a]
burnIn n chain = keepFront <$> chain
  where
    keepFront xs = take (length xs - n) xs