-- |
-- Module      : Control.Monad.Bayes.Traced.Dynamic
-- Description : Distributions on execution traces that can be dynamically frozen
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
module Control.Monad.Bayes.Traced.Dynamic
  ( Traced,
    hoistT,
    marginal,
    freeze,
    mhStep,
    mh,
  )
where

import Control.Monad (join)
import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Free (FreeSampler)
import Control.Monad.Bayes.Traced.Common
import Control.Monad.Bayes.Weighted (Weighted)
import Control.Monad.Trans (MonadTrans (..))

-- | A tracing monad where only a subset of random choices are traced, and
-- that subset can be adjusted dynamically (see 'freeze').
--
-- Running the wrapped action in the base monad @m@ yields a pair:
--
-- * a @'Weighted' ('FreeSampler' m) a@ program, which 'mhStep' and 'mh'
--   pass to 'mhTrans' together with the current trace;
-- * the current 'Trace', whose 'output' is the value returned by
--   'marginal'.
newtype Traced m a = Traced
  { runTraced :: m (Weighted (FreeSampler m) a, Trace a)
  }

-- | Turn a base-monad action that produces a weighted program into a weighted
-- program: the action is lifted through 'FreeSampler' and then 'Weighted',
-- and the program it yields is then run in place.
pushM :: Monad m => m (Weighted (FreeSampler m) a) -> Weighted (FreeSampler m) a
pushM action = do
  inner <- lift (lift action)
  inner

instance Monad m => Functor (Traced m) where
  -- Map the function over both halves of the pair produced by the base
  -- action: the weighted program and the trace.
  fmap f (Traced c) =
    Traced (c >>= \(prog, tr) -> return (fmap f prog, fmap f tr))

instance Monad m => Applicative (Traced m) where
  -- No base-monad effects: the weighted program and the trace are both
  -- 'pure' of the value.
  pure x = Traced (pure (pure x, pure x))

  -- Run the function's base action, then the argument's, and combine the two
  -- weighted programs and the two traces pointwise with '<*>'.
  Traced cf <*> Traced cx =
    Traced $
      cf >>= \(progF, trF) ->
        cx >>= \(progX, trX) ->
          return (progF <*> progX, trF <*> trX)

instance Monad m => Monad (Traced m) where
  Traced cx >>= f = Traced $ do
    (progX, trX) <- cx
    let -- Project the continuation's base action onto one half of its pair.
        progOf x = fmap fst (runTraced (f x))
        traceOf x = fmap snd (runTraced (f x))
        -- The weighted program embeds the continuation's base action (via
        -- 'pushM') for the value produced by 'progX'.
        progY = progX >>= \x -> pushM (progOf x)
    trY <- bind (return trX) traceOf
    return (progY, trY)

instance MonadTrans Traced where
  -- The lifted action is run in the base monad to obtain a value. The
  -- weighted program is the same action lifted through 'FreeSampler' and
  -- 'Weighted', and the trace is 'pure' of the obtained value.
  lift m = Traced (fmap (\x -> (replayed, pure x)) m)
    where
      replayed = lift (lift m)

instance MonadSample m => MonadSample (Traced m) where
  -- Draw a value in the base monad and record it with 'singleton'; the
  -- weighted program is just 'random' in @Weighted (FreeSampler m)@.
  random = Traced (fmap (\r -> (random, singleton r)) random)

instance MonadCond m => MonadCond (Traced m) where
  -- The score is applied immediately in the base monad, recorded in the
  -- trace with 'scored', and also placed in the weighted program.
  score w = Traced (fmap (\tr -> (score w, tr)) (score w >> pure (scored w)))

-- | Empty instance: it only requires the base monad to be 'MonadInfer' and
-- defines no methods of its own. Presumably 'MonadInfer' just combines
-- 'MonadSample' and 'MonadCond', whose 'Traced' instances are defined in this
-- module — TODO confirm against Control.Monad.Bayes.Class.
instance MonadInfer m => MonadInfer (Traced m)

-- | Apply a transformation of the base monad to the outer action of a
-- 'Traced' computation. Because the transformation is polymorphic in the
-- result type, it cannot alter the weighted program or trace values
-- themselves; it only acts on how the outer @m@ action is run.
hoistT :: (forall x. m x -> m x) -> Traced m a -> Traced m a
hoistT f = Traced . f . runTraced

-- | Discard the trace and supporting infrastructure, returning only the
-- 'output' of the trace produced by the underlying action.
marginal :: Monad m => Traced m a -> m a
marginal traced = output . snd <$> runTraced traced

-- | Freeze all traced random choices to their current values and stop tracing
-- them.
--
-- The underlying action is still run. Its weighted program is then discarded
-- and replaced by @return x@, and its trace by @pure x@, where @x@ is the
-- 'output' of the trace it produced.
freeze :: Monad m => Traced m a -> Traced m a
freeze (Traced c) =
  Traced $
    c >>= \(_, tr) ->
      let x = output tr
       in return (return x, pure x)

-- | A single step of the Trace Metropolis-Hastings algorithm.
--
-- Runs the underlying action, hands its weighted program and trace to
-- 'mhTrans', and keeps the program while replacing the trace with the one
-- 'mhTrans' returns.
mhStep :: MonadSample m => Traced m a -> Traced m a
mhStep (Traced c) =
  Traced $
    c >>= \(prog, tr) ->
      mhTrans prog tr >>= \tr' ->
        return (prog, tr')

-- | Full run of the Trace Metropolis-Hastings algorithm with a specified
-- number of steps.
--
-- Returns @n + 1@ samples, newest first. The last element is the output of
-- the initial trace, and each earlier element comes from one further
-- 'mhTrans' step applied to the trace after it. A non-positive @n@ performs
-- no steps and returns just the initial sample.
mh :: MonadSample m => Int -> Traced m a -> m [a]
mh n (Traced c) = do
  (prog, t0) <- c
  let -- Walk the chain with an explicit accumulator so the current trace is
      -- always at hand. This replaces the partial @~(x : xs)@ pattern and a
      -- countdown that only stopped at exactly 0 (negative @n@ never
      -- terminated); it is also tail-recursive in the monad.
      go k current history
        | k <= 0 = return (current : history)
        | otherwise = do
            next <- mhTrans prog current
            go (k - 1) next (current : history)
  ts <- go n t0 []
  return (map output ts)