{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}

-- |
-- Module      : Control.Monad.Bayes.Traced.Static
-- Description : Distributions on execution traces of full programs
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
module Control.Monad.Bayes.Traced.Static
  ( Traced (..),
    hoist,
    marginal,
    mhStep,
    mh,
  )
where

import Control.Applicative (liftA2)
import Control.Monad.Bayes.Class
  ( MonadDistribution (random),
    MonadFactor (..),
    MonadMeasure,
  )
import Control.Monad.Bayes.Density.Free (Density)
import Control.Monad.Bayes.Traced.Common
  ( Trace (..),
    bind,
    mhTransFree,
    scored,
    singleton,
  )
import Control.Monad.Bayes.Weighted (Weighted)
import Control.Monad.Trans (MonadTrans (..))
import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList)

-- | A tracing monad where only a subset of random choices are traced.
--
-- The random choices that are not to be traced should be lifted from the
-- transformed monad.
--
-- A value carries two views of the same program: 'model', which 'mhStep' and
-- 'mh' hand to 'mhTransFree' to propose new traces, and 'traceDist', the
-- distribution over execution traces that 'marginal' projects to outputs.
data Traced m a = Traced
  { -- | The program in the @'Weighted' ('Density' m)@ representation that is
    -- re-run by the Metropolis-Hastings transition kernel.
    model :: Weighted (Density m) a,
    -- | Distribution over execution traces of the program.
    traceDist :: m (Trace a)
  }

-- | Map over the result of both the model and every trace in the trace
-- distribution.
instance Monad m => Functor (Traced m) where
  fmap f (Traced m d) =
    -- The trace distribution is @m (Trace a)@, so the function has to be
    -- pushed through two functor layers: @m@ and then 'Trace'.
    Traced (fmap f m) (fmap (fmap f) d)

-- | Combine two traced programs component-wise: models with the model's
-- applicative, and trace distributions by applying traces inside @m@.
instance Monad m => Applicative (Traced m) where
  -- The inner 'pure' builds a 'Trace' containing just @x@; the outer one
  -- lifts it into @m@.
  pure x = Traced (pure x) (pure (pure x))

  Traced mf df <*> Traced mx dx =
    Traced (mf <*> mx) (liftA2 (<*>) df dx)

-- | Sequence traced programs.
--
-- The continuation is used twice: its 'model' extends the model, and its
-- 'traceDist' extends the trace distribution via 'bind' from
-- "Control.Monad.Bayes.Traced.Common".
instance Monad m => Monad (Traced m) where
  Traced mx dx >>= f = Traced my dy
    where
      -- Model of the combined program.
      my = mx >>= model . f
      -- Trace distribution of the combined program.
      dy = dx `bind` (traceDist . f)

-- | Lift an untraced computation.
--
-- The model sees it through two 'lift's (into @'Density' m@, then into
-- 'Weighted'). In the trace distribution its result is wrapped with 'pure',
-- so no random choice is recorded for it.
instance MonadTrans Traced where
  lift m = Traced (lift (lift m)) (fmap pure m)

-- | A traced primitive random draw.
--
-- The model draws through @'Weighted' ('Density' m)@. The trace
-- distribution draws in @m@ and records the value with 'singleton'.
instance MonadDistribution m => MonadDistribution (Traced m) where
  random = Traced random (singleton <$> random)

-- | Condition on a likelihood factor.
--
-- The factor is applied in both the model and @m@. The trace distribution
-- also returns a trace built by 'scored', which records the weight @w@.
instance MonadFactor m => MonadFactor (Traced m) where
  score w = Traced (score w) (scored w <$ score w)

-- | No methods are defined here: the instance is assembled from the
-- 'MonadDistribution' and 'MonadFactor' instances for @Traced m@
-- (presumably the superclasses of 'MonadMeasure').
instance MonadMeasure m => MonadMeasure (Traced m)

-- | Apply a natural transformation to the trace distribution.
--
-- Only 'traceDist' is transformed; 'model' is kept unchanged.
hoist :: (forall x. m x -> m x) -> Traced m a -> Traced m a
hoist f (Traced m d) = Traced m (f d)

-- | Discard the trace and supporting infrastructure.
--
-- Runs the trace distribution and keeps only each trace's 'output'.
marginal :: Monad m => Traced m a -> m a
marginal (Traced _ d) = output <$> d

-- | A single step of the Trace Metropolis-Hastings algorithm.
--
-- The model is unchanged. Each trace drawn from the current trace
-- distribution is passed through one 'mhTransFree' transition of the model.
mhStep :: MonadDistribution m => Traced m a -> Traced m a
mhStep (Traced m d) = Traced m (d >>= mhTransFree m)

-- $setup
-- >>> import Control.Monad.Bayes.Class
-- >>> import Control.Monad.Bayes.Sampler.Strict
-- >>> import Control.Monad.Bayes.Weighted

-- | Full run of the Trace Metropolis-Hastings algorithm with a specified
-- number of steps. Newest samples are at the head of the list.
--
-- The chain starts from one trace drawn from 'traceDist' and then takes @n@
-- transitions with 'mhTransFree', so the result holds @n + 1@ samples
-- (exactly one when @n <= 0@).
--
-- For example:
--
-- * I have forgotten what day it is.
-- * There are ten buses per hour in the week and three buses per hour at the weekend.
-- * I observe four buses in a given hour.
-- * What is the probability that it is the weekend?
--
-- >>> :{
--  let
--    bus = do x <- bernoulli (2/7)
--             let rate = if x then 3 else 10
--             factor $ poissonPdf rate 4
--             return x
--    mhRunBusSingleObs = do
--      let nSamples = 2
--      sampleIOfixed $ unweighted $ mh nSamples bus
--  in mhRunBusSingleObs
-- :}
-- [True,True,True]
--
-- Of course, it will need to be run for more than two steps to get a reasonable estimate.
mh :: MonadDistribution m => Int -> Traced m a -> m [a]
mh n (Traced m d) = map output . NE.toList <$> (d >>= go n . (:| []))
  where
    -- Perform @k@ more transitions, each starting from the newest trace
    -- (the head of the chain) and prepending its result. Tail-recursive, so
    -- the steps run in order without first nesting @n@ pending binds.
    go k chain@(x :| xs)
      | k <= 0 = pure chain
      | otherwise = do
          y <- mhTransFree m x
          go (k - 1) (y :| (x : xs))