{-|
Module      : Control.Monad.Bayes.Sequential
Description : Suspendable probabilistic computation
Copyright   : (c) Adam Scibior, 2015-2020
License     : MIT
Maintainer  : leonhard.markert@tweag.io
Stability   : experimental
Portability : GHC

'Sequential' represents a computation that can be suspended.
-}

module Control.Monad.Bayes.Sequential (
    Sequential,
    suspend,
    finish,
    advance,
    finished,
    hoistFirst,
    hoist,
    sis
                ) where

import Control.Monad.Trans
import Control.Monad.Coroutine hiding (suspend)
import Control.Monad.Coroutine.SuspensionFunctors
import Data.Either

import Control.Monad.Bayes.Class

-- | Represents a computation that can be suspended at certain points.
-- The intermediate monadic effects can be extracted, which is particularly
-- useful for implementation of Sequential Monte Carlo related methods.
-- All the probabilistic effects are lifted from the transformed monad, but
-- also `suspend` is inserted after each `factor`.
newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a}
  deriving(Functor,Applicative,Monad,MonadTrans,MonadIO)

extract :: Await () a -> a
extract (Await f) = f ()

instance MonadSample m => MonadSample (Sequential m) where
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

-- | Execution is 'suspend'ed after each 'score'.
instance MonadCond m => MonadCond (Sequential m) where
  score w = lift (score w) >> suspend

instance MonadInfer m => MonadInfer (Sequential m)

-- | A point where the computation is paused.
suspend :: Monad m => Sequential m ()
suspend = Sequential await

-- | Remove the remaining suspension points.
finish :: Monad m => Sequential m a -> m a
finish = pogoStick extract . runSequential

-- | Execute to the next suspension point.
-- If the computation is finished, do nothing.
--
-- > finish = finish . advance
advance :: Monad m => Sequential m a -> Sequential m a
advance = Sequential . bounce extract . runSequential

-- | Return True if no more suspension points remain.
finished :: Monad m => Sequential m a -> m Bool
finished = fmap isRight . resume . runSequential

-- | Transform the inner monad.
-- This operation only applies to computation up to the first suspension.
hoistFirst :: (forall x. m x -> m x) -> Sequential m a -> Sequential m a
hoistFirst f = Sequential . Coroutine . f . resume . runSequential

-- | Transform the inner monad.
-- The transformation is applied recursively through all the suspension points.
hoist :: (Monad m, Monad n) =>
            (forall x. m x -> n x) -> Sequential m a -> Sequential n a
hoist f = Sequential . mapMonad f . runSequential

-- | Apply a function a given number of times.
composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies k f = foldr (.) id (replicate k f)

-- | Sequential importance sampling.
-- Applies a given transformation after each time step.
sis :: Monad m
    => (forall x. m x -> m x) -- ^ transformation
    -> Int -- ^ number of time steps
    -> Sequential m a
    -> m a
sis f k = finish . composeCopies k (advance . hoistFirst f)