{-|
Module      : Control.Monad.Bayes.Traced.Basic
Description : Distributions on full execution traces of full programs
Copyright   : (c) Adam Scibior, 2015-2020
License     : MIT
Maintainer  : leonhard.markert@tweag.io
Stability   : experimental
Portability : GHC

-}

module Control.Monad.Bayes.Traced.Basic (
  Traced,
  hoistT,
  marginal,
  mhStep,
  mh
) where

import Data.Functor.Identity
import Control.Applicative (liftA2)

import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Weighted as Weighted
import Control.Monad.Bayes.Free as FreeSampler

import Control.Monad.Bayes.Traced.Common

-- | A monad that keeps track of the random choices a program makes.
--
-- A value carries two views of the same program: a re-runnable model,
-- which can be executed again against a modified trace, and a
-- distribution in @m@ over recorded traces (each trace also holding the
-- output value of the run that produced it).
data Traced m a
  = Traced
      (Weighted (FreeSampler Identity) a) -- re-runnable model
      (m (Trace a))                       -- distribution over traces

-- | Project out the second component of a 'Traced' value: the
-- distribution over traces.
traceDist :: Traced m a -> m (Trace a)
traceDist traced =
  case traced of
    Traced _ dist -> dist

-- | Project out the first component of a 'Traced' value: the
-- re-runnable model program.
model :: Traced m a -> Weighted (FreeSampler Identity) a
model traced =
  case traced of
    Traced prog _ -> prog

-- | Mapping a function transforms the result of the model and, inside
-- @m@, the value carried by each trace.
instance Monad m => Functor (Traced m) where
  fmap g (Traced prog dist) = Traced prog' dist'
    where
      prog' = g <$> prog
      dist' = fmap g <$> dist

-- | Both components are combined independently: the models with their
-- own '<*>', and the trace distributions by lifting the '<*>' of
-- 'Trace' through @m@.
instance Monad m => Applicative (Traced m) where
  pure value = Traced (pure value) (pure (pure value))
  Traced progF distF <*> Traced progX distX = Traced prog dist
    where
      prog = progF <*> progX
      dist = liftA2 (<*>) distF distX

-- | Binding sequences the two components separately: the model is bound
-- with the model part of the continuation, and the trace distribution
-- is combined with the continuation's trace distribution via 'bind'.
instance Monad m => Monad (Traced m) where
  (Traced prog dist) >>= k =
    Traced
      (prog >>= \x -> model (k x))
      (bind dist (\x -> traceDist (k x)))

-- | A random draw is issued in both components; on the trace side the
-- drawn value is wrapped with 'singleton'.
instance MonadSample m => MonadSample (Traced m) where
  random = Traced random drawTrace
    where
      drawTrace = singleton <$> random

-- | Scoring is applied in both components; on the trace side the score
-- is applied in @m@ and then the trace built by 'scored' is returned.
instance MonadCond m => MonadCond (Traced m) where
  score w = Traced (score w) scoredTrace
    where
      scoredTrace = do
        score w
        pure (scored w)

-- | 'Traced' supports inference whenever the underlying monad does.
-- The instance body is intentionally empty: no methods are defined here.
instance MonadInfer m => MonadInfer (Traced m)

-- | Apply a transformation of the underlying monad to the trace
-- distribution. The model component is left untouched.
hoistT :: (forall x. m x -> m x) -> Traced m a -> Traced m a
hoistT transform (Traced prog dist) = Traced prog dist'
  where
    dist' = transform dist

-- | Forget the trace and the model, keeping only the output value of
-- each trace drawn from the trace distribution.
marginal :: Monad m => Traced m a -> m a
marginal (Traced _ dist) = output <$> dist

-- | One step of the Trace MH algorithm: the trace distribution is
-- extended by feeding each trace through 'mhTrans'' for the model.
-- The model component is carried over unchanged.
mhStep :: MonadSample m => Traced m a -> Traced m a
mhStep (Traced prog dist) = Traced prog (mhTrans' prog =<< dist)

-- | Full run of the Trace MH algorithm with a specified
-- number of steps.
--
-- The first argument is the number of transition steps. The result
-- holds the output of the initial trace followed by one output per
-- step, ordered newest first, so it has @n + 1@ elements. A step count
-- of zero or less performs no transitions and yields only the output of
-- the initial trace (previously a negative count never terminated).
mh :: MonadSample m => Int -> Traced m a -> m [a]
mh n (Traced m d) = fmap (map output) (go n)
  where
    go k
      -- Base case also catches negative counts, which would otherwise
      -- recurse forever because they never reach zero.
      | k <= 0 = fmap (:[]) d
      | otherwise = do
          -- The accumulated list always has at least one element (the
          -- base case returns a singleton), so this match cannot fail;
          -- the lazy pattern keeps the bind free of a 'MonadFail'
          -- requirement.
          ~(x:xs) <- go (k - 1)
          y <- mhTrans' m x
          return (y:x:xs)